Repository: thuml/Transfer-Learning-Library
Branch: master
Commit: c4aa59eb5656
Files: 445
Total size: 2.6 MB
Directory structure:
gitextract_h0zbpfvz/
├── .github/
│ └── ISSUE_TEMPLATE/
│ ├── bug_report.md
│ ├── custom.md
│ └── feature_request.md
├── .gitignore
├── CONTRIBUTING.md
├── DATASETS.md
├── LICENSE
├── README.md
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── index.rst
│ ├── make.bat
│ ├── requirements.txt
│ └── tllib/
│ ├── alignment/
│ │ ├── domain_adversarial.rst
│ │ ├── hypothesis_adversarial.rst
│ │ ├── index.rst
│ │ └── statistics_matching.rst
│ ├── modules.rst
│ ├── normalization.rst
│ ├── ranking.rst
│ ├── regularization.rst
│ ├── reweight.rst
│ ├── self_training.rst
│ ├── translation.rst
│ ├── utils/
│ │ ├── analysis.rst
│ │ ├── base.rst
│ │ ├── index.rst
│ │ └── metric.rst
│ └── vision/
│ ├── datasets.rst
│ ├── index.rst
│ ├── models.rst
│ └── transforms.rst
├── examples/
│ ├── domain_adaptation/
│ │ ├── image_classification/
│ │ │ ├── README.md
│ │ │ ├── adda.py
│ │ │ ├── adda.sh
│ │ │ ├── afn.py
│ │ │ ├── afn.sh
│ │ │ ├── bsp.py
│ │ │ ├── bsp.sh
│ │ │ ├── cc_loss.py
│ │ │ ├── cc_loss.sh
│ │ │ ├── cdan.py
│ │ │ ├── cdan.sh
│ │ │ ├── dan.py
│ │ │ ├── dan.sh
│ │ │ ├── dann.py
│ │ │ ├── dann.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── fixmatch.py
│ │ │ ├── fixmatch.sh
│ │ │ ├── jan.py
│ │ │ ├── jan.sh
│ │ │ ├── mcc.py
│ │ │ ├── mcc.sh
│ │ │ ├── mcd.py
│ │ │ ├── mcd.sh
│ │ │ ├── mdd.py
│ │ │ ├── mdd.sh
│ │ │ ├── requirements.txt
│ │ │ ├── self_ensemble.py
│ │ │ ├── self_ensemble.sh
│ │ │ └── utils.py
│ │ ├── image_regression/
│ │ │ ├── README.md
│ │ │ ├── dann.py
│ │ │ ├── dann.sh
│ │ │ ├── dd.py
│ │ │ ├── dd.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── rsd.py
│ │ │ ├── rsd.sh
│ │ │ └── utils.py
│ │ ├── keypoint_detection/
│ │ │ ├── README.md
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── regda.py
│ │ │ ├── regda.sh
│ │ │ ├── regda_fast.py
│ │ │ └── regda_fast.sh
│ │ ├── object_detection/
│ │ │ ├── README.md
│ │ │ ├── config/
│ │ │ │ ├── faster_rcnn_R_101_C4_cityscapes.yaml
│ │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml
│ │ │ │ ├── faster_rcnn_vgg_16_cityscapes.yaml
│ │ │ │ └── retinanet_R_101_FPN_voc.yaml
│ │ │ ├── cycle_gan.py
│ │ │ ├── cycle_gan.sh
│ │ │ ├── d_adapt/
│ │ │ │ ├── README.md
│ │ │ │ ├── bbox_adaptation.py
│ │ │ │ ├── category_adaptation.py
│ │ │ │ ├── config/
│ │ │ │ │ ├── faster_rcnn_R_101_C4_cityscapes.yaml
│ │ │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml
│ │ │ │ │ ├── faster_rcnn_vgg_16_cityscapes.yaml
│ │ │ │ │ └── retinanet_R_101_FPN_voc.yaml
│ │ │ │ ├── d_adapt.py
│ │ │ │ └── d_adapt.sh
│ │ │ ├── oracle.sh
│ │ │ ├── prepare_cityscapes_to_voc.py
│ │ │ ├── requirements.txt
│ │ │ ├── source_only.py
│ │ │ ├── source_only.sh
│ │ │ ├── utils.py
│ │ │ ├── visualize.py
│ │ │ └── visualize.sh
│ │ ├── openset_domain_adaptation/
│ │ │ ├── README.md
│ │ │ ├── dann.py
│ │ │ ├── dann.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── osbp.py
│ │ │ ├── osbp.sh
│ │ │ └── utils.py
│ │ ├── partial_domain_adaptation/
│ │ │ ├── README.md
│ │ │ ├── afn.py
│ │ │ ├── afn.sh
│ │ │ ├── dann.py
│ │ │ ├── dann.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── iwan.py
│ │ │ ├── iwan.sh
│ │ │ ├── pada.py
│ │ │ ├── pada.sh
│ │ │ ├── requirements.txt
│ │ │ └── utils.py
│ │ ├── re_identification/
│ │ │ ├── README.md
│ │ │ ├── baseline.py
│ │ │ ├── baseline.sh
│ │ │ ├── baseline_cluster.py
│ │ │ ├── baseline_cluster.sh
│ │ │ ├── ibn.sh
│ │ │ ├── mmt.py
│ │ │ ├── mmt.sh
│ │ │ ├── requirements.txt
│ │ │ ├── spgan.py
│ │ │ ├── spgan.sh
│ │ │ └── utils.py
│ │ ├── semantic_segmentation/
│ │ │ ├── README.md
│ │ │ ├── advent.py
│ │ │ ├── advent.sh
│ │ │ ├── cycada.py
│ │ │ ├── cycada.sh
│ │ │ ├── cycle_gan.py
│ │ │ ├── cycle_gan.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── fda.py
│ │ │ └── fda.sh
│ │ ├── wilds_image_classification/
│ │ │ ├── README.md
│ │ │ ├── cdan.py
│ │ │ ├── cdan.sh
│ │ │ ├── dan.py
│ │ │ ├── dan.sh
│ │ │ ├── dann.py
│ │ │ ├── dann.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── fixmatch.py
│ │ │ ├── fixmatch.sh
│ │ │ ├── jan.py
│ │ │ ├── jan.sh
│ │ │ ├── mdd.py
│ │ │ ├── mdd.sh
│ │ │ ├── requirements.txt
│ │ │ └── utils.py
│ │ ├── wilds_ogb_molpcba/
│ │ │ ├── README.md
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── gin.py
│ │ │ ├── requirements.txt
│ │ │ └── utils.py
│ │ ├── wilds_poverty/
│ │ │ ├── README.md
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── requirements.txt
│ │ │ ├── resnet_ms.py
│ │ │ └── utils.py
│ │ └── wilds_text/
│ │ ├── README.md
│ │ ├── erm.py
│ │ ├── erm.sh
│ │ ├── requirements.txt
│ │ └── utils.py
│ ├── domain_generalization/
│ │ ├── image_classification/
│ │ │ ├── README.md
│ │ │ ├── coral.py
│ │ │ ├── coral.sh
│ │ │ ├── erm.py
│ │ │ ├── erm.sh
│ │ │ ├── groupdro.py
│ │ │ ├── groupdro.sh
│ │ │ ├── ibn.sh
│ │ │ ├── irm.py
│ │ │ ├── irm.sh
│ │ │ ├── mixstyle.py
│ │ │ ├── mixstyle.sh
│ │ │ ├── mldg.py
│ │ │ ├── mldg.sh
│ │ │ ├── requirements.txt
│ │ │ ├── utils.py
│ │ │ ├── vrex.py
│ │ │ └── vrex.sh
│ │ └── re_identification/
│ │ ├── README.md
│ │ ├── baseline.py
│ │ ├── baseline.sh
│ │ ├── ibn.sh
│ │ ├── mixstyle.py
│ │ ├── mixstyle.sh
│ │ ├── requirements.txt
│ │ └── utils.py
│ ├── model_selection/
│ │ ├── README.md
│ │ ├── hscore.py
│ │ ├── hscore.sh
│ │ ├── leep.py
│ │ ├── leep.sh
│ │ ├── logme.py
│ │ ├── logme.sh
│ │ ├── nce.py
│ │ ├── nce.sh
│ │ ├── requirements.txt
│ │ └── utils.py
│ ├── semi_supervised_learning/
│ │ └── image_classification/
│ │ ├── README.md
│ │ ├── convert_moco_to_pretrained.py
│ │ ├── debiasmatch.py
│ │ ├── debiasmatch.sh
│ │ ├── dst.py
│ │ ├── dst.sh
│ │ ├── erm.py
│ │ ├── erm.sh
│ │ ├── fixmatch.py
│ │ ├── fixmatch.sh
│ │ ├── flexmatch.py
│ │ ├── flexmatch.sh
│ │ ├── mean_teacher.py
│ │ ├── mean_teacher.sh
│ │ ├── noisy_student.py
│ │ ├── noisy_student.sh
│ │ ├── pi_model.py
│ │ ├── pi_model.sh
│ │ ├── pseudo_label.py
│ │ ├── pseudo_label.sh
│ │ ├── requirements.txt
│ │ ├── self_tuning.py
│ │ ├── self_tuning.sh
│ │ ├── uda.py
│ │ ├── uda.sh
│ │ └── utils.py
│ └── task_adaptation/
│ └── image_classification/
│ ├── README.md
│ ├── bi_tuning.py
│ ├── bi_tuning.sh
│ ├── bss.py
│ ├── bss.sh
│ ├── co_tuning.py
│ ├── co_tuning.sh
│ ├── convert_moco_to_pretrained.py
│ ├── delta.py
│ ├── delta.sh
│ ├── erm.py
│ ├── erm.sh
│ ├── lwf.py
│ ├── lwf.sh
│ ├── requirements.txt
│ ├── stochnorm.py
│ ├── stochnorm.sh
│ └── utils.py
├── requirements.txt
├── setup.py
└── tllib/
├── __init__.py
├── alignment/
│ ├── __init__.py
│ ├── adda.py
│ ├── advent.py
│ ├── bsp.py
│ ├── cdan.py
│ ├── coral.py
│ ├── d_adapt/
│ │ ├── __init__.py
│ │ ├── feedback.py
│ │ ├── modeling/
│ │ │ ├── __init__.py
│ │ │ ├── matcher.py
│ │ │ ├── meta_arch/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── rcnn.py
│ │ │ │ └── retinanet.py
│ │ │ └── roi_heads/
│ │ │ ├── __init__.py
│ │ │ ├── fast_rcnn.py
│ │ │ └── roi_heads.py
│ │ └── proposal.py
│ ├── dan.py
│ ├── dann.py
│ ├── jan.py
│ ├── mcd.py
│ ├── mdd.py
│ ├── osbp.py
│ ├── regda.py
│ └── rsd.py
├── modules/
│ ├── __init__.py
│ ├── classifier.py
│ ├── domain_discriminator.py
│ ├── entropy.py
│ ├── gl.py
│ ├── grl.py
│ ├── kernels.py
│ ├── loss.py
│ └── regressor.py
├── normalization/
│ ├── __init__.py
│ ├── afn.py
│ ├── ibn.py
│ ├── mixstyle/
│ │ ├── __init__.py
│ │ ├── resnet.py
│ │ └── sampler.py
│ └── stochnorm.py
├── ranking/
│ ├── __init__.py
│ ├── hscore.py
│ ├── leep.py
│ ├── logme.py
│ ├── nce.py
│ └── transrate.py
├── regularization/
│ ├── __init__.py
│ ├── bi_tuning.py
│ ├── bss.py
│ ├── co_tuning.py
│ ├── delta.py
│ ├── knowledge_distillation.py
│ └── lwf.py
├── reweight/
│ ├── __init__.py
│ ├── groupdro.py
│ ├── iwan.py
│ └── pada.py
├── self_training/
│ ├── __init__.py
│ ├── cc_loss.py
│ ├── dst.py
│ ├── flexmatch.py
│ ├── mcc.py
│ ├── mean_teacher.py
│ ├── pi_model.py
│ ├── pseudo_label.py
│ ├── self_ensemble.py
│ ├── self_tuning.py
│ └── uda.py
├── translation/
│ ├── __init__.py
│ ├── cycada.py
│ ├── cyclegan/
│ │ ├── __init__.py
│ │ ├── discriminator.py
│ │ ├── generator.py
│ │ ├── loss.py
│ │ ├── transform.py
│ │ └── util.py
│ ├── fourier_transform.py
│ └── spgan/
│ ├── __init__.py
│ ├── loss.py
│ └── siamese.py
├── utils/
│ ├── __init__.py
│ ├── analysis/
│ │ ├── __init__.py
│ │ ├── a_distance.py
│ │ └── tsne.py
│ ├── data.py
│ ├── logger.py
│ ├── meter.py
│ ├── metric/
│ │ ├── __init__.py
│ │ ├── keypoint_detection.py
│ │ └── reid.py
│ └── scheduler.py
└── vision/
├── __init__.py
├── datasets/
│ ├── __init__.py
│ ├── _util.py
│ ├── aircrafts.py
│ ├── caltech101.py
│ ├── cifar.py
│ ├── coco70.py
│ ├── cub200.py
│ ├── digits.py
│ ├── domainnet.py
│ ├── dtd.py
│ ├── eurosat.py
│ ├── food101.py
│ ├── imagelist.py
│ ├── imagenet_r.py
│ ├── imagenet_sketch.py
│ ├── keypoint_detection/
│ │ ├── __init__.py
│ │ ├── freihand.py
│ │ ├── hand_3d_studio.py
│ │ ├── human36m.py
│ │ ├── keypoint_dataset.py
│ │ ├── lsp.py
│ │ ├── rendered_hand_pose.py
│ │ ├── surreal.py
│ │ └── util.py
│ ├── object_detection/
│ │ └── __init__.py
│ ├── office31.py
│ ├── officecaltech.py
│ ├── officehome.py
│ ├── openset/
│ │ └── __init__.py
│ ├── oxfordflowers.py
│ ├── oxfordpets.py
│ ├── pacs.py
│ ├── partial/
│ │ ├── __init__.py
│ │ ├── caltech_imagenet.py
│ │ └── imagenet_caltech.py
│ ├── patchcamelyon.py
│ ├── regression/
│ │ ├── __init__.py
│ │ ├── dsprites.py
│ │ ├── image_regression.py
│ │ └── mpi3d.py
│ ├── reid/
│ │ ├── __init__.py
│ │ ├── basedataset.py
│ │ ├── convert.py
│ │ ├── dukemtmc.py
│ │ ├── market1501.py
│ │ ├── msmt17.py
│ │ ├── personx.py
│ │ └── unreal.py
│ ├── resisc45.py
│ ├── retinopathy.py
│ ├── segmentation/
│ │ ├── __init__.py
│ │ ├── cityscapes.py
│ │ ├── gta5.py
│ │ ├── segmentation_list.py
│ │ └── synthia.py
│ ├── stanford_cars.py
│ ├── stanford_dogs.py
│ ├── sun397.py
│ └── visda2017.py
├── models/
│ ├── __init__.py
│ ├── digits.py
│ ├── keypoint_detection/
│ │ ├── __init__.py
│ │ ├── loss.py
│ │ └── pose_resnet.py
│ ├── object_detection/
│ │ ├── __init__.py
│ │ ├── backbone/
│ │ │ ├── __init__.py
│ │ │ ├── mmdetection/
│ │ │ │ ├── vgg.py
│ │ │ │ └── weight_init.py
│ │ │ └── vgg.py
│ │ ├── meta_arch/
│ │ │ ├── __init__.py
│ │ │ ├── rcnn.py
│ │ │ └── retinanet.py
│ │ ├── proposal_generator/
│ │ │ ├── __init__.py
│ │ │ └── rpn.py
│ │ └── roi_heads/
│ │ ├── __init__.py
│ │ └── roi_heads.py
│ ├── reid/
│ │ ├── __init__.py
│ │ ├── identifier.py
│ │ ├── loss.py
│ │ └── resnet.py
│ ├── resnet.py
│ └── segmentation/
│ ├── __init__.py
│ └── deeplabv2.py
└── transforms/
├── __init__.py
├── keypoint_detection.py
└── segmentation.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. iOS]
- Browser [e.g. chrome, safari]
- Version [e.g. 22]
**Smartphone (please complete the following information):**
- Device: [e.g. iPhone6]
- OS: [e.g. iOS8.1]
- Browser [e.g. stock browser, safari]
- Version [e.g. 22]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/ISSUE_TEMPLATE/custom.md
================================================
---
name: Custom issue template
about: Describe this issue template's purpose here.
title: ''
labels: ''
assignees: ''
---
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
docs/build/*
docs/pytorch_sphinx_theme/*
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.idea/
exp/*
trash/*
examples/domain_adaptation/digits/logs/*
examples/domain_adaptation/digits/data/*
.DS_Store
*/.DS_Store
================================================
FILE: CONTRIBUTING.md
================================================
## Contributing to Transfer-Learning-Library
All kinds of contributions are welcome, including but not limited to the following.
- Fix typo or bugs
- Add documentation
- Add new features and components
### Workflow
1. fork and pull the latest Transfer-Learning-Library repository
2. checkout a new branch (do not use master branch for PRs)
3. commit your changes
4. create a PR
```{note}
If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first.
```
================================================
FILE: DATASETS.md
================================================
## Notice (2023-08-01)
### Transfer-Learning-Library Dataset Link Failure Issue
Dear users,
We sincerely apologize to inform you that the dataset links of Transfer-Learning-Library have become invalid due to a cloud storage failure, resulting in many users being unable to download the datasets properly recently.
We are working diligently to resolve this issue and plan to restore the links as soon as possible. Currently, we have already restored some dataset links, and they have been updated on the master branch. You can obtain the latest version by running "git pull."
As the version on PyPI has not been updated yet, please temporarily uninstall the old version by running "pip uninstall tllib" before installing the new one.
In the future, we are planning to store the datasets on both Baidu Cloud and Google Cloud to provide more stable download links.
Additionally, a small portion of datasets that were backed up on our local server have also been lost due to a hard disk failure. For these datasets, we need to re-download and verify them, which might take longer to restore the links.
Within this week, we will release the updated dataset and confirm the list of datasets without backup. For datasets without backup, if you have previously downloaded them locally, please contact us via email. Your support is highly appreciated.
Once again, we apologize for any inconvenience caused and thank you for your understanding.
Sincerely,
The Transfer-Learning-Library Team
## Update (2023-08-09)
Most of the dataset links have been restored at present. The confirmed datasets without backups are as follows:
- Classification
- COCO70
- EuroSAT
- PACS
- PatchCamelyon
- [Partial Domain Adaptation]
- CaltechImageNet
- Keypoint Detection
- Hand3DStudio
- LSP
- SURREAL
- Object Detection
- Comic
- Re-Identification
- PersonX
- UnrealPerson
**For these datasets, if you had previously downloaded them locally, please contact us via email. We greatly appreciate everyone's support.**
## Notice (2023-08-01)
### Transfer-Learning-Library数据集链接失效问题
各位使用者,我们很抱歉通知大家,最近Transfer-Learning-Library的数据集链接因为云盘故障而失效,导致很多使用者无法正常下载数据集。
我们正在全力以赴解决这一问题,并计划在最短的时间内恢复链接。目前我们已经恢复了部分数据集链接,更新在master分支上,您可以通过git pull来获取最新的版本。
由于pypi上的版本还未更新,暂时请首先通过pip uninstall tllib卸载旧版本。
日后我们计划将数据集存储在百度云和谷歌云上,提供更加稳定的下载链接。
另外,小部分数据集在我们本地服务器上的备份也由于硬盘故障而丢失,对于这些数据集我们需要重新下载并验证,可能需要更长的时间来恢复链接。
我们会在本周内发布已经更新的数据集和确认无备份的数据集列表,对于无备份的数据集,如果您之前有下载到本地,请通过邮件联系我们,非常感谢大家的支持。
再次向您表达我们的歉意,并感谢您的理解。
Transfer-Learning-Library团队
## Update (2023-08-09)
目前大部分数据集的链接已经恢复,确认无备份的数据集如下:
- Classification
- COCO70
- EuroSAT
- PACS
- PatchCamelyon
- [Partial Domain Adaptation]
- CaltechImageNet
- Keypoint Detection
- Hand3DStudio
- LSP
- SURREAL
- Object Detection
- Comic
- Re-Identification
- PersonX
- UnrealPerson
**对于这些数据集,如果您之前有下载到本地,请通过邮件联系我们,非常感谢大家的支持。**
================================================
FILE: LICENSE
================================================
Copyright (c) 2018 The Python Packaging Authority
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Transfer Learning Library
- [Introduction](#introduction)
- [Updates](#updates)
- [Supported Methods](#supported-methods)
- [Installation](#installation)
- [Documentation](#documentation)
- [Contact](#contact)
- [Citation](#citation)
## Update (2024-03-15)
We have uploaded an offline version of the documentation [here](/docs/html.zip). You can download and unzip it to view the documentation.
## Notice (2023-08-09)
A note on broken dataset links can be found here: [DATASETS.md](DATASETS.md).
## Introduction
*TLlib* is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms, or readily apply existing algorithms.
Our _API_ is divided by methods, which include:
- domain alignment methods (tllib.alignment)
- domain translation methods (tllib.translation)
- self-training methods (tllib.self_training)
- regularization methods (tllib.regularization)
- data reweighting/resampling methods (tllib.reweight)
- model ranking/selection methods (tllib.ranking)
- normalization-based methods (tllib.normalization)
We provide many example codes in the directory _examples_, which is divided by learning setups. Currently, the supported learning setups include:
- DA (domain adaptation)
- TA (task adaptation, also known as finetune)
- OOD (out-of-distribution generalization, also known as DG / domain generalization)
- SSL (semi-supervised learning)
- Model Selection
Our supported tasks include: classification, regression, object detection, segmentation, keypoint detection, and so on.
## Updates
### 2022.9
We support installing *TLlib* via `pip`, which is experimental currently.
```shell
pip install -i https://test.pypi.org/simple/ tllib==0.4
```
### 2022.8
We release `v0.4` of *TLlib*. Previous versions of *TLlib* can be found [here](https://github.com/thuml/Transfer-Learning-Library/releases). In `v0.4`, we add implementations of
the following methods:
- Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection) [[API]](/tllib/alignment/d_adapt)
- Pre-trained Model Selection [[Code]](/examples/model_selection) [[API]](/tllib/ranking)
- Semi-supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/) [[API]](/tllib/self_training)
Besides, we maintain a collection of **_awesome papers in Transfer Learning_** in another repo [_A Roadmap for Transfer Learning_](https://github.com/thuml/A-Roadmap-for-Transfer-Learning).
### 2022.2
We adjusted our API following our survey [Transferability in Deep Learning](https://arxiv.org/abs/2201.05867).
## Supported Methods
The currently supported algorithms include:
##### Domain Adaptation for Classification [[Code]](/examples/domain_adaptation/image_classification)
- **DANN** - Unsupervised Domain Adaptation by Backpropagation [[ICML 2015]](http://proceedings.mlr.press/v37/ganin15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dann.py)
- **DAN** - Learning Transferable Features with Deep Adaptation Networks [[ICML 2015]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-adaptation-networks-icml15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dan.py)
- **JAN** - Deep Transfer Learning with Joint Adaptation Networks [[ICML 2017]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/joint-adaptation-networks-icml17.pdf) [[Code]](/examples/domain_adaptation/image_classification/jan.py)
- **ADDA** - Adversarial Discriminative Domain Adaptation [[CVPR 2017]](http://openaccess.thecvf.com/content_cvpr_2017/papers/Tzeng_Adversarial_Discriminative_Domain_CVPR_2017_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/adda.py)
- **CDAN** - Conditional Adversarial Domain Adaptation [[NIPS 2018]](http://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation) [[Code]](/examples/domain_adaptation/image_classification/cdan.py)
- **MCD** - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation [[CVPR 2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Saito_Maximum_Classifier_Discrepancy_CVPR_2018_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcd.py)
- **MDD** - Bridging Theory and Algorithm for Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/zhang19i/zhang19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/mdd.py)
- **BSP** - Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/chen19i/chen19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/bsp.py)
- **MCC** - Minimum Class Confusion for Versatile Domain Adaptation [[ECCV 2020]](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123660460.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcc.py)
##### Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection)
- **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/object_detection/cycle_gan.py)
- **D-adapt** - Decoupled Adaptation for Cross-Domain Object Detection [[ICLR 2022]](https://openreview.net/pdf?id=VNqaB1g9393) [[Code]](/examples/domain_adaptation/object_detection/d_adapt)
##### Domain Adaptation for Semantic Segmentation [[Code]](/examples/domain_adaptation/semantic_segmentation/)
- **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycle_gan.py)
- **CyCADA** - Cycle-Consistent Adversarial Domain Adaptation [[ICML 2018]](http://proceedings.mlr.press/v80/hoffman18a.html) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycada.py)
- **ADVENT** - Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation [[CVPR 2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Vu_ADVENT_Adversarial_Entropy_Minimization_for_Domain_Adaptation_in_Semantic_Segmentation_CVPR_2019_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/advent.py)
- **FDA** - Fourier Domain Adaptation for Semantic Segmentation [[CVPR 2020]](https://arxiv.org/abs/2004.05498) [[Code]](/examples/domain_adaptation/semantic_segmentation/fda.py)
##### Domain Adaptation for Keypoint Detection [[Code]](/examples/domain_adaptation/keypoint_detection)
- **RegDA** - Regressive Domain Adaptation for Unsupervised Keypoint Detection [[CVPR 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf) [[Code]](/examples/domain_adaptation/keypoint_detection)
##### Domain Adaptation for Person Re-identification [[Code]](/examples/domain_adaptation/re_identification/)
- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- **MMT** - Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification [[ICLR 2020]](https://arxiv.org/abs/2001.01526) [[Code]](/examples/domain_adaptation/re_identification/mmt.py)
- **SPGAN** - Similarity Preserving Generative Adversarial Network [[CVPR 2018]](https://arxiv.org/pdf/1811.10551.pdf) [[Code]](/examples/domain_adaptation/re_identification/spgan.py)
##### Partial Domain Adaptation [[Code]](/examples/domain_adaptation/partial_domain_adaptation)
- **IWAN** - Importance Weighted Adversarial Nets for Partial Domain Adaptation [[CVPR 2018]](https://arxiv.org/abs/1803.09210) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/iwan.py)
- **AFN** - Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation [[ICCV 2019]](https://arxiv.org/pdf/1811.07456v2.pdf) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/afn.py)
##### Open-set Domain Adaptation [[Code]](/examples/domain_adaptation/openset_domain_adaptation)
- **OSBP** - Open Set Domain Adaptation by Backpropagation [[ECCV 2018]](https://arxiv.org/abs/1804.10427) [[Code]](/examples/domain_adaptation/openset_domain_adaptation/osbp.py)
##### Domain Generalization for Classification [[Code]](/examples/domain_generalization/image_classification/)
- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/image_classification/mixstyle.py)
- **MLDG** - Learning to Generalize: Meta-Learning for Domain Generalization [[AAAI 2018]](https://arxiv.org/pdf/1710.03463.pdf) [[Code]](/examples/domain_generalization/image_classification/mldg.py)
- **IRM** - Invariant Risk Minimization [[ArXiv]](https://arxiv.org/abs/1907.02893) [[Code]](/examples/domain_generalization/image_classification/irm.py)
- **VREx** - Out-of-Distribution Generalization via Risk Extrapolation [[ICML 2021]](https://arxiv.org/abs/2003.00688) [[Code]](/examples/domain_generalization/image_classification/vrex.py)
- **GroupDRO** - Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization [[ArXiv]](https://arxiv.org/abs/1911.08731) [[Code]](/examples/domain_generalization/image_classification/groupdro.py)
- **Deep CORAL** - Correlation Alignment for Deep Domain Adaptation [[ECCV 2016]](https://arxiv.org/abs/1607.01719) [[Code]](/examples/domain_generalization/image_classification/coral.py)
##### Domain Generalization for Person Re-identification [[Code]](/examples/domain_generalization/re_identification/)
- **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/re_identification/mixstyle.py)
##### Task Adaptation (Fine-Tuning) for Image Classification [[Code]](/examples/task_adaptation/image_classification/)
- **L2-SP** - Explicit inductive bias for transfer learning with convolutional networks [[ICML 2018]](https://arxiv.org/abs/1802.01483) [[Code]](/examples/task_adaptation/image_classification/delta.py)
- **BSS** - Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning [[NIPS 2019]](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/bss.py)
- **DELTA** - DEep Learning Transfer using Feature Map with Attention for convolutional networks [[ICLR 2019]](https://openreview.net/pdf?id=rkgbwsAcYm) [[Code]](/examples/task_adaptation/image_classification/delta.py)
- **Co-Tuning** - Co-Tuning for Transfer Learning [[NIPS 2020]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf) [[Code]](/examples/task_adaptation/image_classification/co_tuning.py)
- **StochNorm** - Stochastic Normalization [[NIPS 2020]](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/stochnorm.py)
- **LWF** - Learning Without Forgetting [[ECCV 2016]](https://arxiv.org/abs/1606.09282) [[Code]](/examples/task_adaptation/image_classification/lwf.py)
- **Bi-Tuning** - Bi-tuning of Pre-trained Representations [[ArXiv]](https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29) [[Code]](/examples/task_adaptation/image_classification/bi_tuning.py)
##### Pre-trained Model Selection [[Code]](/examples/model_selection)
- **H-Score** - An Information-theoretic Approach to Transferability in Task Transfer Learning [[ICIP 2019]](http://yangli-feasibility.com/home/media/icip-19.pdf) [[Code]](/examples/model_selection/hscore.py)
- **NCE** - Negative Conditional Entropy in Transferability and Hardness of Supervised Classification Tasks [[ICCV 2019]](https://arxiv.org/pdf/1908.08142v1.pdf) [[Code]](/examples/model_selection/nce.py)
- **LEEP** - LEEP: A New Measure to Evaluate Transferability of Learned Representations [[ICML 2020]](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf) [[Code]](/examples/model_selection/leep.py)
- **LogME** - Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models for Transfer Learning` [[ICML 2021]](https://arxiv.org/pdf/2102.11005.pdf) [[Code]](/examples/model_selection/logme.py)
##### Semi-Supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/)
- **Pseudo Label** - Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks [[ICML 2013]](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf) [[Code]](/examples/semi_supervised_learning/image_classification/pseudo_label.py)
- **Pi Model** - Temporal Ensembling for Semi-Supervised Learning [[ICLR 2017]](https://arxiv.org/abs/1610.02242) [[Code]](/examples/semi_supervised_learning/image_classification/pi_model.py)
- **Mean Teacher** - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results [[NIPS 2017]](https://arxiv.org/abs/1703.01780) [[Code]](/examples/semi_supervised_learning/image_classification/mean_teacher.py)
- **Noisy Student** - Self-Training With Noisy Student Improves ImageNet Classification [[CVPR 2020]](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/noisy_student.py)
- **UDA** - Unsupervised Data Augmentation for Consistency Training [[NIPS 2020]](https://arxiv.org/pdf/1904.12848v4.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/uda.py)
- **FixMatch** - Simplifying Semi-Supervised Learning with Consistency and Confidence [[NIPS 2020]](https://arxiv.org/abs/2001.07685) [[Code]](/examples/semi_supervised_learning/image_classification/fixmatch.py)
- **Self-Tuning** - Self-Tuning for Data-Efficient Deep Learning [[ICML 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/self_tuning.py)
- **FlexMatch** - FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling [[NIPS 2021]](https://arxiv.org/abs/2110.08263) [[Code]](/examples/semi_supervised_learning/image_classification/flexmatch.py)
- **DebiasMatch** - Debiased Learning From Naturally Imbalanced Pseudo-Labels [[CVPR 2022]](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/debiasmatch.py)
- **DST** - Debiased Self-Training for Semi-Supervised Learning [[NIPS 2022 Oral]](https://arxiv.org/abs/2202.07136) [[Code]](/examples/semi_supervised_learning/image_classification/dst.py)
## Installation
##### Install from Source Code
- Please git clone the library first. Then, run the following commands to install `tllib` and all of its dependencies.
```shell
python setup.py install
pip install -r requirements.txt
```
##### Install via `pip`
- Installing via `pip` is currently experimental.
```shell
pip install -i https://test.pypi.org/simple/ tllib==0.4
```
## Documentation
You can find the API documentation on the website: [Documentation](http://tl.thuml.ai/).
## Usage
You can find examples in the directory `examples`. A typical usage is
```shell script
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
# Assume you have put the datasets under the path `data/office31`,
# or you are glad to download the datasets automatically from the Internet to this path
python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20
```
## Contributing
We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us.
## Disclaimer on Datasets
This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have licenses to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license.
If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community!
## Contact
If you have any problem with our code or have some suggestions, including future features, feel free to contact
- Baixu Chen (cbx_99_hasta@outlook.com)
- Junguang Jiang (JiangJunguang1123@outlook.com)
- Mingsheng Long (longmingsheng@gmail.com)
or describe it in Issues.
For Q&A in Chinese, you can choose to ask questions here before sending an email. [迁移学习算法库答疑专区](https://zhuanlan.zhihu.com/p/248104070)
## Citation
If you use this toolbox or benchmark in your research, please cite this project.
```latex
@misc{jiang2022transferability,
title={Transferability in Deep Learning: A Survey},
author={Junguang Jiang and Yang Shu and Jianmin Wang and Mingsheng Long},
year={2022},
eprint={2201.05867},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
@misc{tllib,
author = {Junguang Jiang and Baixu Chen and Bo Fu and Mingsheng Long},
title = {Transfer-Learning-library},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/thuml/Transfer-Learning-Library}},
}
```
## Acknowledgment
We would like to thank School of Software, Tsinghua University and The National Engineering Laboratory for Big Data Software for providing such an excellent ML research platform.
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS  =
SPHINXBUILD = sphinx-build
SPHINXPROJ  = PyTorchSphinxTheme
SOURCEDIR   = .
BUILDDIR    = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
"""Sphinx configuration for the Transfer Learning Library documentation.

Builds the API reference with autodoc/autosummary and renders it with the
``pytorch_sphinx_theme`` HTML theme.
"""
import sys
import os

# Make the repository root (which contains the ``tllib`` package) and the
# demo directory importable so autodoc can resolve documented modules.
sys.path.append(os.path.abspath('..'))
sys.path.append(os.path.abspath('./demo/'))

# Theme version doubles as the documentation version (see ``version`` below).
from pytorch_sphinx_theme import __version__
import pytorch_sphinx_theme

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))

# -- General configuration -----------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
    'sphinx.ext.intersphinx',
    'sphinx.ext.autodoc',
    'sphinx.ext.viewcode',
    'sphinxcontrib.httpdomain',
    'sphinx.ext.autosummary',
    'sphinx.ext.autosectionlabel',
    'sphinx.ext.napoleon',
]

# build the templated autosummary files
autosummary_generate = True
numpydoc_show_class_members = False

# autosectionlabel throws warnings if section names are duplicated.
# The following tells autosectionlabel to not throw a warning for
# duplicated section names that are in different documents.
autosectionlabel_prefix_document = True

# Napoleon: document instance attributes with ``:ivar:`` fields.
napoleon_use_ivar = True

# Do not warn about external images (status badges in README.rst)
suppress_warnings = ['image.nonlocal_uri']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix of source filenames.
source_suffix = '.rst'

# The encoding of source files.
#source_encoding = 'utf-8-sig'

# The master toctree document.
master_doc = 'index'

# General information about the project.
project = u'Transfer Learning Library'
copyright = u'THUML Group'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
version = __version__
# The full version, including alpha/beta/rc tags.
release = __version__

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
language = 'en'

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = []

# The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None

# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True

# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []

# Cross-reference targets for :external: links to these projects' docs.
intersphinx_mapping = {
    'rtd': ('https://docs.readthedocs.io/en/latest/', None),
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable', None),
    'torchvision': ('https://pytorch.org/vision/stable', None),
    'PIL': ('https://pillow.readthedocs.io/en/stable/', None)
}

# -- Options for HTML output ---------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    'canonical_url': '',
    'analytics_id': '',
    'logo_only': False,
    'display_version': False,
    'prev_next_buttons_location': 'bottom',
    'style_external_links': False,
    # Toc options
    'collapse_navigation': True,
    'sticky_navigation': False,
    'navigation_depth': 4,
    'includehidden': True,
    'titles_only': False
}

# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
#html_title = None

# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
html_logo = "_static/images/TransLearn.png"

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True

# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}

# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}

# If false, no module index is generated.
#html_domain_indices = True

# If false, no index is generated.
#html_use_index = True

# If true, the index is split into individual pages for each letter.
#html_split_index = False

# If true, links to the reST sources are added to the pages.
html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True

# If true, an OpenSearch description file will be output, and all pages will
# contain a tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''

# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None

# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'

# Output file base name for HTML help builder.
htmlhelp_basename = 'TransferLearningLibrary'

# -- Options for LaTeX output --------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #'preamble': '',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
    ('index', 'TransferLearningLibrary.tex', u'Transfer Learning Library Documentation',
     u'THUML', 'manual'),
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False

# If true, show page references after internal links.
#latex_show_pagerefs = False

# If true, show URL addresses after external links.
#latex_show_urls = False

# Documents to append as an appendix to all manuals.
#latex_appendices = []

# If false, no module index is generated.
#latex_domain_indices = True

# -- Options for manual page output --------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation',
     [u'THUML'], 1)
]

# If true, show URL addresses after external links.
#man_show_urls = False

# -- Options for Texinfo output ------------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation',
     u'THUML', 'Transfer Learning Library',
     'One line description of project.', 'Miscellaneous'),
]

# Documents to append as an appendix to all manuals.
#texinfo_appendices = []

# If false, no module index is generated.
#texinfo_domain_indices = True

# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'
================================================
FILE: docs/index.rst
================================================
=====================================
Transfer Learning
=====================================
.. toctree::
:maxdepth: 2
:caption: Transfer Learning API
:titlesonly:
tllib/modules
tllib/alignment/index
tllib/translation
tllib/self_training
tllib/reweight
tllib/normalization
tllib/regularization
tllib/ranking
.. toctree::
:maxdepth: 2
:caption: Common API
:titlesonly:
tllib/vision/index
tllib/utils/index
================================================
FILE: docs/make.bat
================================================
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation.
REM
REM Respect a caller-provided SPHINXBUILD and only fall back to the default
REM 'sphinx-build' when the variable is unset. (Previously the variable was
REM unconditionally overwritten after the conditional, making the fallback
REM dead code and ignoring the user's setting.)
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SPHINXOPTS=
set SOURCEDIR=.
set BUILDDIR=build
set SPHINXPROJ=PyTorchSphinxTheme

REM With no target argument, show Sphinx's help.
if "%1" == "" goto help

REM Probe that the sphinx-build executable exists before building.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The Sphinx module was not found. Make sure you have Sphinx installed,
	echo.then set the SPHINXBUILD environment variable to point to the full
	echo.path of the 'sphinx-build' executable. Alternatively you may add the
	echo.Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.http://sphinx-doc.org/
	exit /b 1
)

%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%

:end
popd
================================================
FILE: docs/requirements.txt
================================================
sphinxcontrib-httpdomain
sphinx
================================================
FILE: docs/tllib/alignment/domain_adversarial.rst
================================================
==========================================
Domain Adversarial Training
==========================================
.. _DANN:
DANN: Domain Adversarial Neural Network
----------------------------------------
.. autoclass:: tllib.alignment.dann.DomainAdversarialLoss
.. _CDAN:
CDAN: Conditional Domain Adversarial Network
-----------------------------------------------
.. autoclass:: tllib.alignment.cdan.ConditionalDomainAdversarialLoss
.. autoclass:: tllib.alignment.cdan.RandomizedMultiLinearMap
.. autoclass:: tllib.alignment.cdan.MultiLinearMap
.. _ADDA:
ADDA: Adversarial Discriminative Domain Adaptation
-----------------------------------------------------
.. autoclass:: tllib.alignment.adda.DomainAdversarialLoss
.. note::
    ADDAgrl is also implemented and benchmarked. You can find the code
    `here <https://github.com/thuml/Transfer-Learning-Library/tree/master/examples/domain_adaptation/image_classification>`_.
.. _BSP:
BSP: Batch Spectral Penalization
-----------------------------------
.. autoclass:: tllib.alignment.bsp.BatchSpectralPenalizationLoss
.. _OSBP:
OSBP: Open Set Domain Adaptation by Backpropagation
----------------------------------------------------
.. autoclass:: tllib.alignment.osbp.UnknownClassBinaryCrossEntropy
.. _ADVENT:
ADVENT: Adversarial Entropy Minimization for Semantic Segmentation
------------------------------------------------------------------
.. autoclass:: tllib.alignment.advent.Discriminator
.. autoclass:: tllib.alignment.advent.DomainAdversarialEntropyLoss
:members:
.. _DADAPT:
D-adapt: Decoupled Adaptation for Cross-Domain Object Detection
----------------------------------------------------------------
`Origin Paper <https://arxiv.org/abs/2110.02578>`_.
.. autoclass:: tllib.alignment.d_adapt.proposal.Proposal
.. autoclass:: tllib.alignment.d_adapt.proposal.PersistentProposalList
.. autoclass:: tllib.alignment.d_adapt.proposal.ProposalDataset
.. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledGeneralizedRCNN
.. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledRetinaNet
================================================
FILE: docs/tllib/alignment/hypothesis_adversarial.rst
================================================
==========================================
Hypothesis Adversarial Learning
==========================================
.. _MCD:
MCD: Maximum Classifier Discrepancy
--------------------------------------------
.. autofunction:: tllib.alignment.mcd.classifier_discrepancy
.. autofunction:: tllib.alignment.mcd.entropy
.. autoclass:: tllib.alignment.mcd.ImageClassifierHead
.. _MDD:
MDD: Margin Disparity Discrepancy
--------------------------------------------
.. autoclass:: tllib.alignment.mdd.MarginDisparityDiscrepancy
**MDD for Classification**
.. autoclass:: tllib.alignment.mdd.ClassificationMarginDisparityDiscrepancy
.. autoclass:: tllib.alignment.mdd.ImageClassifier
:members:
.. autofunction:: tllib.alignment.mdd.shift_log
**MDD for Regression**
.. autoclass:: tllib.alignment.mdd.RegressionMarginDisparityDiscrepancy
.. autoclass:: tllib.alignment.mdd.ImageRegressor
.. _RegDA:
RegDA: Regressive Domain Adaptation
--------------------------------------------
.. autoclass:: tllib.alignment.regda.PseudoLabelGenerator2d
.. autoclass:: tllib.alignment.regda.RegressionDisparity
.. autoclass:: tllib.alignment.regda.PoseResNet2d
================================================
FILE: docs/tllib/alignment/index.rst
================================================
=====================================
Feature Alignment
=====================================
.. toctree::
:maxdepth: 3
:caption: Feature Alignment
:titlesonly:
statistics_matching
domain_adversarial
hypothesis_adversarial
================================================
FILE: docs/tllib/alignment/statistics_matching.rst
================================================
=====================
Statistics Matching
=====================
.. _DAN:
DAN: Deep Adaptation Network
-----------------------------
.. autoclass:: tllib.alignment.dan.MultipleKernelMaximumMeanDiscrepancy
.. _CORAL:
Deep CORAL: Correlation Alignment for Deep Domain Adaptation
--------------------------------------------------------------
.. autoclass:: tllib.alignment.coral.CorrelationAlignmentLoss
.. _JAN:
JAN: Joint Adaptation Network
------------------------------
.. autoclass:: tllib.alignment.jan.JointMultipleKernelMaximumMeanDiscrepancy
================================================
FILE: docs/tllib/modules.rst
================================================
=====================
Modules
=====================
Classifier
-------------------------------
.. autoclass:: tllib.modules.classifier.Classifier
:members:
Regressor
-------------------------------
.. autoclass:: tllib.modules.regressor.Regressor
:members:
Domain Discriminator
-------------------------------
.. autoclass:: tllib.modules.domain_discriminator.DomainDiscriminator
:members:
GRL: Gradient Reverse Layer
-----------------------------
.. autoclass:: tllib.modules.grl.WarmStartGradientReverseLayer
:members:
Gaussian Kernels
------------------------
.. autoclass:: tllib.modules.kernels.GaussianKernel
Entropy
------------------------
.. autofunction:: tllib.modules.entropy.entropy
Knowledge Distillation Loss
-------------------------------
.. autoclass:: tllib.modules.loss.KnowledgeDistillationLoss
:members:
================================================
FILE: docs/tllib/normalization.rst
================================================
=====================
Normalization
=====================
.. _AFN:
AFN: Adaptive Feature Norm
-----------------------------
.. autoclass:: tllib.normalization.afn.AdaptiveFeatureNorm
.. autoclass:: tllib.normalization.afn.Block
.. autoclass:: tllib.normalization.afn.ImageClassifier
StochNorm: Stochastic Normalization
------------------------------------------
.. autoclass:: tllib.normalization.stochnorm.StochNorm1d
.. autoclass:: tllib.normalization.stochnorm.StochNorm2d
.. autoclass:: tllib.normalization.stochnorm.StochNorm3d
.. autofunction:: tllib.normalization.stochnorm.convert_model
.. _IBN:
IBN-Net: Instance-Batch Normalization Network
------------------------------------------------
.. autoclass:: tllib.normalization.ibn.InstanceBatchNorm2d
.. autoclass:: tllib.normalization.ibn.IBNNet
:members:
.. automodule:: tllib.normalization.ibn
:members:
.. _MIXSTYLE:
MixStyle: Domain Generalization with MixStyle
-------------------------------------------------
.. autoclass:: tllib.normalization.mixstyle.MixStyle
.. note::
MixStyle is only activated during `training` stage, with some probability :math:`p`.
.. automodule:: tllib.normalization.mixstyle.resnet
:members:
================================================
FILE: docs/tllib/ranking.rst
================================================
=====================
Ranking
=====================
.. _H_score:
H-score
-------------------------------------------
.. autofunction:: tllib.ranking.hscore.h_score
.. _LEEP:
LEEP: Log Expected Empirical Prediction
-------------------------------------------
.. autofunction:: tllib.ranking.leep.log_expected_empirical_prediction
.. _NCE:
NCE: Negative Conditional Entropy
-------------------------------------------
.. autofunction:: tllib.ranking.nce.negative_conditional_entropy
.. _LogME:
LogME: Log Maximum Evidence
-------------------------------------------
.. autofunction:: tllib.ranking.logme.log_maximum_evidence
================================================
FILE: docs/tllib/regularization.rst
================================================
===========================================
Regularization
===========================================
.. _L2:
L2
------
.. autoclass:: tllib.regularization.delta.L2Regularization
.. _L2SP:
L2-SP
------
.. autoclass:: tllib.regularization.delta.SPRegularization
.. _DELTA:
DELTA: DEep Learning Transfer using Feature Map with Attention
-------------------------------------------------------------------------------------
.. autoclass:: tllib.regularization.delta.BehavioralRegularization
.. autoclass:: tllib.regularization.delta.AttentionBehavioralRegularization
.. autoclass:: tllib.regularization.delta.IntermediateLayerGetter
.. _LWF:
LWF: Learning without Forgetting
------------------------------------------
.. autoclass:: tllib.regularization.lwf.Classifier
.. _CoTuning:
Co-Tuning
------------------------------------------
.. autoclass:: tllib.regularization.co_tuning.CoTuningLoss
.. autoclass:: tllib.regularization.co_tuning.Relationship
.. _StochNorm:
.. _BiTuning:
Bi-Tuning
------------------------------------------
.. autoclass:: tllib.regularization.bi_tuning.BiTuning
.. _BSS:
BSS: Batch Spectral Shrinkage
------------------------------------------
.. autoclass:: tllib.regularization.bss.BatchSpectralShrinkage
================================================
FILE: docs/tllib/reweight.rst
================================================
=======================================
Re-weighting
=======================================
.. _PADA:
PADA: Partial Adversarial Domain Adaptation
---------------------------------------------
.. autoclass:: tllib.reweight.pada.ClassWeightModule
.. autoclass:: tllib.reweight.pada.AutomaticUpdateClassWeightModule
:members:
.. autofunction:: tllib.reweight.pada.collect_classification_results
.. _IWAN:
IWAN: Importance Weighted Adversarial Nets
---------------------------------------------
.. autoclass:: tllib.reweight.iwan.ImportanceWeightModule
:members:
.. _GroupDRO:
GroupDRO: Group Distributionally robust optimization
------------------------------------------------------
.. autoclass:: tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule
:members:
================================================
FILE: docs/tllib/self_training.rst
================================================
=======================================
Self Training Methods
=======================================
.. _PseudoLabel:
Pseudo Label
-----------------------------
.. autoclass:: tllib.self_training.pseudo_label.ConfidenceBasedSelfTrainingLoss
.. _PiModel:
:math:`\Pi` Model
-----------------------------
.. autoclass:: tllib.self_training.pi_model.ConsistencyLoss
.. autoclass:: tllib.self_training.pi_model.L2ConsistencyLoss
.. _MeanTeacher:
Mean Teacher
-----------------------------
.. autoclass:: tllib.self_training.mean_teacher.EMATeacher
.. _SelfEnsemble:
Self Ensemble
-----------------------------
.. autoclass:: tllib.self_training.self_ensemble.ClassBalanceLoss
.. _UDA:
UDA
-----------------------------
.. autoclass:: tllib.self_training.uda.StrongWeakConsistencyLoss
.. _MCC:
MCC: Minimum Class Confusion
-----------------------------
.. autoclass:: tllib.self_training.mcc.MinimumClassConfusionLoss
.. _MMT:
MMT: Mutual Mean-Teaching
--------------------------
`Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised
Domain Adaptation on Person Re-identification (ICLR 2020) <https://arxiv.org/abs/2001.01526>`_
State of the art unsupervised domain adaptation methods utilize clustering algorithms to generate pseudo labels on target
domain, which are noisy and thus harmful for training. Inspired by the teacher-student approaches, MMT framework
provides robust soft pseudo labels in an on-line peer-teaching manner.
We denote two networks as :math:`f_1,f_2`, their parameters as :math:`\theta_1,\theta_2`. The authors also
propose to use the temporally average model of each network :math:`\text{ensemble}(f_1),\text{ensemble}(f_2)` to generate more reliable
soft pseudo labels for supervising the other network. Specifically, the parameters of the temporally
average models of the two networks at current iteration :math:`T` are denoted as :math:`E^{(T)}[\theta_1]` and
:math:`E^{(T)}[\theta_2]` respectively, which can be calculated as
.. math::
E^{(T)}[\theta_1] = \alpha E^{(T-1)}[\theta_1] + (1-\alpha)\theta_1
.. math::
E^{(T)}[\theta_2] = \alpha E^{(T-1)}[\theta_2] + (1-\alpha)\theta_2
where :math:`E^{(T-1)}[\theta_1],E^{(T-1)}[\theta_2]` indicate the temporal average parameters of the two networks in
the previous iteration :math:`(T-1)`, the initial temporal average parameters are
:math:`E^{(0)}[\theta_1]=\theta_1,E^{(0)}[\theta_2]=\theta_2` and :math:`\alpha` is the momentum.
These two networks cooperate with each other in three ways:
- When running clustering algorithm, we average features produced by :math:`\text{ensemble}(f_1)` and
:math:`\text{ensemble}(f_2)` instead of only considering one of them.
- A **soft triplet loss** is optimized between :math:`f_1` and :math:`\text{ensemble}(f_2)` and vice versa
to force one network to learn from temporally average of another network.
- A **cross entropy loss** is optimized between :math:`f_1` and :math:`\text{ensemble}(f_2)` and vice versa
to force one network to learn from temporally average of another network.
The above mentioned loss functions are listed below, more details can be found in training scripts.
.. autoclass:: tllib.vision.models.reid.loss.SoftTripletLoss
.. autoclass:: tllib.vision.models.reid.loss.CrossEntropyLoss
.. _SelfTuning:
Self Tuning
-----------------------------
.. autoclass:: tllib.self_training.self_tuning.Classifier
.. autoclass:: tllib.self_training.self_tuning.SelfTuning
.. _FlexMatch:
FlexMatch
-----------------------------
.. autoclass:: tllib.self_training.flexmatch.DynamicThresholdingModule
:members:
.. _DST:
Debiased Self-Training
-----------------------------
.. autoclass:: tllib.self_training.dst.ImageClassifier
.. autoclass:: tllib.self_training.dst.WorstCaseEstimationLoss
================================================
FILE: docs/tllib/translation.rst
================================================
=======================================
Domain Translation
=======================================
.. _CycleGAN:
------------------------------------------------
CycleGAN: Cycle-Consistent Adversarial Networks
------------------------------------------------
Discriminator
--------------
.. autofunction:: tllib.translation.cyclegan.pixel
.. autofunction:: tllib.translation.cyclegan.patch
Generator
--------------
.. autofunction:: tllib.translation.cyclegan.resnet_9
.. autofunction:: tllib.translation.cyclegan.resnet_6
.. autofunction:: tllib.translation.cyclegan.unet_256
.. autofunction:: tllib.translation.cyclegan.unet_128
GAN Loss
--------------
.. autoclass:: tllib.translation.cyclegan.LeastSquaresGenerativeAdversarialLoss
.. autoclass:: tllib.translation.cyclegan.VanillaGenerativeAdversarialLoss
.. autoclass:: tllib.translation.cyclegan.WassersteinGenerativeAdversarialLoss
Translation
--------------
.. autoclass:: tllib.translation.cyclegan.Translation
Util
----------------
.. autoclass:: tllib.translation.cyclegan.util.ImagePool
:members:
.. autofunction:: tllib.translation.cyclegan.util.set_requires_grad
.. _Cycada:
--------------------------------------------------------------
CyCADA: Cycle-Consistent Adversarial Domain Adaptation
--------------------------------------------------------------
.. autoclass:: tllib.translation.cycada.SemanticConsistency
.. _SPGAN:
-----------------------------------------------------------
SPGAN: Similarity Preserving Generative Adversarial Network
-----------------------------------------------------------
`Image-Image Domain Adaptation with Preserved Self-Similarity and Domain-Dissimilarity for Person Re-identification
<https://arxiv.org/abs/1711.07027>`_. SPGAN is based on CycleGAN. An additional Siamese network is adopted to force
the generator to produce images different from identities in the target dataset.
Siamese Network
-------------------
.. autoclass:: tllib.translation.spgan.siamese.SiameseNetwork
Contrastive Loss
-------------------
.. autoclass:: tllib.translation.spgan.loss.ContrastiveLoss
.. _FDA:
------------------------------------------------
FDA: Fourier Domain Adaptation
------------------------------------------------
.. autoclass:: tllib.translation.fourier_transform.FourierTransform
.. autofunction:: tllib.translation.fourier_transform.low_freq_mutate
================================================
FILE: docs/tllib/utils/analysis.rst
================================================
==============
Analysis Tools
==============
.. autofunction:: tllib.utils.analysis.collect_feature
.. autofunction:: tllib.utils.analysis.a_distance.calculate
.. autofunction:: tllib.utils.analysis.tsne.visualize
================================================
FILE: docs/tllib/utils/base.rst
================================================
Generic Tools
==============
Average Meter
---------------------------------
.. autoclass:: tllib.utils.meter.AverageMeter
:members:
Progress Meter
---------------------------------
.. autoclass:: tllib.utils.meter.ProgressMeter
:members:
Meter
---------------------------------
.. autoclass:: tllib.utils.meter.Meter
:members:
Data
---------------------------------
.. autoclass:: tllib.utils.data.ForeverDataIterator
:members:
.. autoclass:: tllib.utils.data.CombineDataset
:members:
.. autofunction:: tllib.utils.data.send_to_device
.. autofunction:: tllib.utils.data.concatenate
Logger
-----------
.. autoclass:: tllib.utils.logger.TextLogger
:members:
.. autoclass:: tllib.utils.logger.CompleteLogger
:members:
================================================
FILE: docs/tllib/utils/index.rst
================================================
=====================================
Utilities
=====================================
.. toctree::
:maxdepth: 2
:caption: Utilities
:titlesonly:
base
metric
analysis
================================================
FILE: docs/tllib/utils/metric.rst
================================================
===========
Metrics
===========
Classification & Segmentation
==============================
Accuracy
---------------------------------
.. autofunction:: tllib.utils.metric.accuracy
ConfusionMatrix
---------------------------------
.. autoclass:: tllib.utils.metric.ConfusionMatrix
:members:
================================================
FILE: docs/tllib/vision/datasets.rst
================================================
Datasets
=============================
Cross-Domain Classification
---------------------------------------------------------
--------------------------------------
ImageList
--------------------------------------
.. autoclass:: tllib.vision.datasets.imagelist.ImageList
:members:
-------------------------------------
Office-31
-------------------------------------
.. autoclass:: tllib.vision.datasets.office31.Office31
:members:
:inherited-members:
---------------------------------------
Office-Caltech
---------------------------------------
.. autoclass:: tllib.vision.datasets.officecaltech.OfficeCaltech
:members:
:inherited-members:
---------------------------------------
Office-Home
---------------------------------------
.. autoclass:: tllib.vision.datasets.officehome.OfficeHome
:members:
:inherited-members:
--------------------------------------
VisDA-2017
--------------------------------------
.. autoclass:: tllib.vision.datasets.visda2017.VisDA2017
:members:
:inherited-members:
--------------------------------------
DomainNet
--------------------------------------
.. autoclass:: tllib.vision.datasets.domainnet.DomainNet
:members:
:inherited-members:
--------------------------------------
PACS
--------------------------------------
.. autoclass:: tllib.vision.datasets.pacs.PACS
:members:
--------------------------------------
MNIST
--------------------------------------
.. autoclass:: tllib.vision.datasets.digits.MNIST
:members:
--------------------------------------
USPS
--------------------------------------
.. autoclass:: tllib.vision.datasets.digits.USPS
:members:
--------------------------------------
SVHN
--------------------------------------
.. autoclass:: tllib.vision.datasets.digits.SVHN
:members:
Partial Cross-Domain Classification
----------------------------------------------------
---------------------------------------
Partial Wrapper
---------------------------------------
.. autofunction:: tllib.vision.datasets.partial.partial
.. autofunction:: tllib.vision.datasets.partial.default_partial
---------------------------------------
Caltech-256->ImageNet-1k
---------------------------------------
.. autoclass:: tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet
:members:
---------------------------------------
ImageNet-1k->Caltech-256
---------------------------------------
.. autoclass:: tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech
:members:
Open Set Cross-Domain Classification
------------------------------------------------------
---------------------------------------
Open Set Wrapper
---------------------------------------
.. autofunction:: tllib.vision.datasets.openset.open_set
.. autofunction:: tllib.vision.datasets.openset.default_open_set
Cross-Domain Regression
------------------------------------------------------
---------------------------------------
ImageRegression
---------------------------------------
.. autoclass:: tllib.vision.datasets.regression.image_regression.ImageRegression
:members:
---------------------------------------
DSprites
---------------------------------------
.. autoclass:: tllib.vision.datasets.regression.dsprites.DSprites
:members:
---------------------------------------
MPI3D
---------------------------------------
.. autoclass:: tllib.vision.datasets.regression.mpi3d.MPI3D
:members:
Cross-Domain Segmentation
-----------------------------------------------
---------------------------------------
SegmentationList
---------------------------------------
.. autoclass:: tllib.vision.datasets.segmentation.segmentation_list.SegmentationList
:members:
---------------------------------------
Cityscapes
---------------------------------------
.. autoclass:: tllib.vision.datasets.segmentation.cityscapes.Cityscapes
---------------------------------------
GTA5
---------------------------------------
.. autoclass:: tllib.vision.datasets.segmentation.gta5.GTA5
---------------------------------------
Synthia
---------------------------------------
.. autoclass:: tllib.vision.datasets.segmentation.synthia.Synthia
---------------------------------------
Foggy Cityscapes
---------------------------------------
.. autoclass:: tllib.vision.datasets.segmentation.cityscapes.FoggyCityscapes
Cross-Domain Keypoint Detection
-----------------------------------------------
---------------------------------------
Dataset Base for Keypoint Detection
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.KeypointDataset
:members:
.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Body16KeypointDataset
:members:
.. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Hand21KeypointDataset
:members:
---------------------------------------
Rendered Handpose Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.rendered_hand_pose.RenderedHandPose
:members:
---------------------------------------
Hand-3d-Studio Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.hand_3d_studio.Hand3DStudio
:members:
---------------------------------------
FreiHAND Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.freihand.FreiHand
:members:
---------------------------------------
Surreal Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.surreal.SURREAL
:members:
---------------------------------------
LSP Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.lsp.LSP
:members:
---------------------------------------
Human3.6M Dataset
---------------------------------------
.. autoclass:: tllib.vision.datasets.keypoint_detection.human36m.Human36M
:members:
Cross-Domain ReID
------------------------------------------------------
---------------------------------------
Market1501
---------------------------------------
.. autoclass:: tllib.vision.datasets.reid.market1501.Market1501
:members:
---------------------------------------
DukeMTMC-reID
---------------------------------------
.. autoclass:: tllib.vision.datasets.reid.dukemtmc.DukeMTMC
:members:
---------------------------------------
MSMT17
---------------------------------------
.. autoclass:: tllib.vision.datasets.reid.msmt17.MSMT17
:members:
Natural Object Recognition
---------------------------------------------------------
-------------------------------------
Stanford Dogs
-------------------------------------
.. autoclass:: tllib.vision.datasets.stanford_dogs.StanfordDogs
:members:
-------------------------------------
Stanford Cars
-------------------------------------
.. autoclass:: tllib.vision.datasets.stanford_cars.StanfordCars
:members:
-------------------------------------
CUB-200-2011
-------------------------------------
.. autoclass:: tllib.vision.datasets.cub200.CUB200
:members:
-------------------------------------
FVGC Aircraft
-------------------------------------
.. autoclass:: tllib.vision.datasets.aircrafts.Aircraft
:members:
-------------------------------------
Oxford-IIIT Pets
-------------------------------------
.. autoclass:: tllib.vision.datasets.oxfordpets.OxfordIIITPets
:members:
-------------------------------------
COCO-70
-------------------------------------
.. autoclass:: tllib.vision.datasets.coco70.COCO70
:members:
-------------------------------------
DTD
-------------------------------------
.. autoclass:: tllib.vision.datasets.dtd.DTD
:members:
-------------------------------------
OxfordFlowers102
-------------------------------------
.. autoclass:: tllib.vision.datasets.oxfordflowers.OxfordFlowers102
:members:
-------------------------------------
Caltech101
-------------------------------------
.. autoclass:: tllib.vision.datasets.caltech101.Caltech101
:members:
Specialized Image Classification
--------------------------------
-------------------------------------
PatchCamelyon
-------------------------------------
.. autoclass:: tllib.vision.datasets.patchcamelyon.PatchCamelyon
:members:
-------------------------------------
Retinopathy
-------------------------------------
.. autoclass:: tllib.vision.datasets.retinopathy.Retinopathy
:members:
-------------------------------------
EuroSAT
-------------------------------------
.. autoclass:: tllib.vision.datasets.eurosat.EuroSAT
:members:
-------------------------------------
Resisc45
-------------------------------------
.. autoclass:: tllib.vision.datasets.resisc45.Resisc45
:members:
-------------------------------------
Food-101
-------------------------------------
.. autoclass:: tllib.vision.datasets.food101.Food101
:members:
-------------------------------------
SUN397
-------------------------------------
.. autoclass:: tllib.vision.datasets.sun397.SUN397
:members:
================================================
FILE: docs/tllib/vision/index.rst
================================================
=====================================
Vision
=====================================
.. toctree::
:maxdepth: 2
:caption: Vision
:titlesonly:
datasets
models
transforms
================================================
FILE: docs/tllib/vision/models.rst
================================================
Models
===========================
------------------------------
Image Classification
------------------------------
ResNets
---------------------------------
.. automodule:: tllib.vision.models.resnet
:members:
LeNet
--------------------------
.. automodule:: tllib.vision.models.digits.lenet
:members:
DTN
--------------------------
.. automodule:: tllib.vision.models.digits.dtn
:members:
----------------------------------
Object Detection
----------------------------------
.. autoclass:: tllib.vision.models.object_detection.meta_arch.TLGeneralizedRCNN
:members:
.. autoclass:: tllib.vision.models.object_detection.meta_arch.TLRetinaNet
:members:
.. autoclass:: tllib.vision.models.object_detection.proposal_generator.rpn.TLRPN
.. autoclass:: tllib.vision.models.object_detection.roi_heads.TLRes5ROIHeads
:members:
.. autoclass:: tllib.vision.models.object_detection.roi_heads.TLStandardROIHeads
:members:
----------------------------------
Semantic Segmentation
----------------------------------
.. autofunction:: tllib.vision.models.segmentation.deeplabv2.deeplabv2_resnet101
----------------------------------
Keypoint Detection
----------------------------------
PoseResNet
--------------------------
.. autofunction:: tllib.vision.models.keypoint_detection.pose_resnet.pose_resnet101
.. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.PoseResNet
.. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.Upsampling
Joint Loss
----------------------------------
.. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsMSELoss
.. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsKLLoss
-----------------------------------
Re-Identification
-----------------------------------
Models
---------------
.. autoclass:: tllib.vision.models.reid.resnet.ReidResNet
.. automodule:: tllib.vision.models.reid.resnet
:members:
.. autoclass:: tllib.vision.models.reid.identifier.ReIdentifier
:members:
Loss
-----------------------------------
.. autoclass:: tllib.vision.models.reid.loss.TripletLoss
Sampler
-----------------------------------
.. autoclass:: tllib.utils.data.RandomMultipleGallerySampler
================================================
FILE: docs/tllib/vision/transforms.rst
================================================
Transforms
=============================
Classification
---------------------------------
.. automodule:: tllib.vision.transforms
:members:
Segmentation
---------------------------------
.. automodule:: tllib.vision.transforms.segmentation
:members:
Keypoint Detection
---------------------------------
.. automodule:: tllib.vision.transforms.keypoint_detection
:members:
================================================
FILE: examples/domain_adaptation/image_classification/README.md
================================================
# Unsupervised Domain Adaptation for Image Classification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [MNIST](http://yann.lecun.com/exdb/mnist/), [SVHN](http://ufldl.stanford.edu/housenumbers/)
, [USPS](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps)
- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeCaltech](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
- [VisDA2017](http://ai.bu.edu/visda-2017/)
- [DomainNet](http://ai.bu.edu/M3SDA/)
You need to prepare following datasets manually if you want to use them:
- [ImageNet](https://www.image-net.org/)
- [ImageNetR](https://github.com/hendrycks/imagenet-r)
- [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch)
and prepare them following [Documentation for ImageNetR](/common/vision/datasets/imagenet_r.py)
and [ImageNet-Sketch](/common/vision/datasets/imagenet_sketch.py).
## Supported Methods
Supported methods include:
- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)
- [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791)
- [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636)
- [Adversarial Discriminative Domain Adaptation (ADDA)](https://arxiv.org/pdf/1702.05464.pdf)
- [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667)
- [Maximum Classifier Discrepancy (MCD)](https://arxiv.org/abs/1712.02560)
- [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf)
- [Batch Spectral Penalization (BSP)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf)
- [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)
- [Minimum Class Confusion (MCC)](https://arxiv.org/abs/1912.03699)
- [FixMatch](https://arxiv.org/abs/2001.07685)
## Usage
The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to
train DANN on Office31, use the following script
```shell script
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
# Assume you have put the datasets under the path `data/office31`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
```
Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store
results.
After running the above command, it will download ``Office-31`` datasets from the Internet if it's the first time you
run the code. Directory that stores datasets will be named as
``examples/domain_adaptation/image_classification/data/``.
If everything works fine, you will see results in following format::
Epoch: [1][ 900/1000] Time 0.60 ( 0.69) Data 0.22 ( 0.31) Loss 0.74 ( 0.85) Cls Acc 96.9 (95.1) Domain Acc 64.1 (62.6)
You can also watch these results in the log file ``logs/dann/Office31_A2W/log.txt``.
After training, you can test your algorithm's performance by passing in ``--phase test``.
```
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase test
```
## Experiment and Results
**Notations**
- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.
We found that the accuracies of adversarial methods (including DANN, ADDA, CDAN, MCD, BSP and MDD) are not stable even
after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017*
for three times and report their average accuracy.
### Office-31 accuracy on ResNet-50
| Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
|---------|--------|------|-------|-------|-------|-------|-------|-------|
| ERM | 76.1 | 79.5 | 75.8 | 95.5 | 99.0 | 79.3 | 63.6 | 63.8 |
| DANN | 82.2 | 86.1 | 91.4 | 97.9 | 100.0 | 83.6 | 73.3 | 70.4 |
| ADDA | / | 87.3 | 94.6 | 97.5 | 99.7 | 90.0 | 69.6 | 72.5 |
| BSP | 87.7 | 87.8 | 92.7 | 97.9 | 100.0 | 88.2 | 74.1 | 73.8 |
| DAN | 80.4 | 83.7 | 84.2 | 98.4 | 100.0 | 87.3 | 66.9 | 65.2 |
| JAN | 84.3 | 87.0 | 93.7 | 98.4 | 100.0 | 89.4 | 69.2 | 71.0 |
| CDAN | 87.7 | 87.7 | 93.8 | 98.5 | 100.0 | 89.9 | 73.4 | 70.4 |
| MCD | / | 85.4 | 90.4 | 98.5 | 100.0 | 87.3 | 68.3 | 67.6 |
| AFN | 85.7 | 88.6 | 94.0 | 98.9 | 100.0 | 94.4 | 72.9 | 71.1 |
| MDD | 88.9 | 89.6 | 95.6 | 98.6 | 100.0 | 94.4 | 76.6 | 72.2 |
| MCC | 89.4 | 89.6 | 94.1 | 98.4 | 99.8 | 95.6 | 75.5 | 74.2 |
| FixMatch| / | 86.4 | 86.4 | 98.2 | 100.0 | 95.4 | 70.0 | 68.1 |
### Office-Home accuracy on ResNet-50
| Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |
|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|
| ERM | 46.1 | 58.4 | 41.1 | 65.9 | 73.7 | 53.1 | 60.1 | 63.3 | 52.2 | 36.7 | 71.8 | 64.8 | 42.6 | 75.2 |
| DAN | 56.3 | 61.4 | 45.6 | 67.7 | 73.9 | 57.7 | 63.8 | 66.0 | 54.9 | 40.0 | 74.5 | 66.2 | 49.1 | 77.9 |
| DANN | 57.6 | 65.2 | 53.8 | 62.6 | 74.0 | 55.8 | 67.3 | 67.3 | 55.8 | 55.1 | 77.9 | 71.1 | 60.7 | 81.1 |
| ADDA | / | 65.6 | 52.6 | 62.9 | 74.0 | 59.7 | 68.0 | 68.8 | 61.4 | 52.5 | 77.6 | 71.1 | 58.6 | 80.2 |
| JAN | 58.3 | 65.9 | 50.8 | 71.9 | 76.5 | 60.6 | 68.3 | 68.7 | 60.5 | 49.6 | 76.9 | 71.0 | 55.9 | 80.5 |
| CDAN | 65.8 | 68.8 | 55.2 | 72.4 | 77.6 | 62.0 | 69.7 | 70.9 | 62.4 | 54.3 | 80.5 | 75.5 | 61.0 | 83.8 |
| MCD | / | 67.8 | 51.7 | 72.2 | 78.2 | 63.7 | 69.5 | 70.8 | 61.5 | 52.8 | 78.0 | 74.5 | 58.4 | 81.8 |
| BSP | 64.9 | 67.6 | 54.7 | 67.7 | 76.2 | 61.0 | 69.4 | 70.9 | 60.9 | 55.2 | 80.2 | 73.4 | 60.3 | 81.2 |
| AFN | 67.3 | 68.2 | 53.2 | 72.7 | 76.8 | 65.0 | 71.3 | 72.3 | 65.0 | 51.4 | 77.9 | 72.3 | 57.8 | 82.4 |
| MDD | 68.1 | 69.7 | 56.2 | 75.4 | 79.6 | 63.5 | 72.1 | 73.8 | 62.5 | 54.8 | 79.9 | 73.5 | 60.9 | 84.5 |
| MCC | / | 72.4 | 58.4 | 79.6 | 83.0 | 67.5 | 77.0 | 78.5 | 66.6 | 54.8 | 81.8 | 74.4 | 61.4 | 85.6 |
| FixMatch | / | 70.8 | 56.4 | 76.4 | 79.9 | 65.3 | 73.8 | 71.2 | 67.2 | 56.4 | 80.6 | 74.9 | 63.5 | 84.3 |
### Office-Home accuracy on vit_base_patch16_224 (batch size 24)
| Methods | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr | Avg |
|-------------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|------|
| Source Only | 52.4 | 82.1 | 86.9 | 76.8 | 84.1 | 86 | 75.1 | 51.2 | 88.1 | 78.3 | 51.5 | 87.8 | 75.0 |
| DANN | 60.1 | 80.8 | 87.9 | 78.1 | 82.6 | 85.9 | 78.8 | 63.2 | 90.2 | 82.3 | 64 | 89.3 | 78.6 |
| DAN | 56.3 | 83.6 | 87.5 | 77.7 | 84.7 | 86.7 | 75.9 | 54.5 | 88.5 | 80.2 | 56.2 | 88.2 | 76.7 |
| JAN | 60.1 | 86.9 | 88.6 | 79.2 | 85.4 | 86.7 | 80.4 | 59.4 | 89.6 | 82 | 60.7 | 89.9 | 79.1 |
| CDAN | 61.6 | 87.8 | 89.6 | 81.4 | 88.1 | 88.5 | 82.4 | 62.5 | 90.8 | 84.2 | 63.5 | 90.8 | 80.9 |
| MCD | 52.3 | 75.3 | 85.3 | 75.4 | 75.4 | 78.3 | 68.8 | 49.7 | 86 | 80.6 | 60 | 89 | 73.0 |
| AFN | 58.3 | 87.2 | 88.2 | 81.7 | 87 | 88.2 | 81 | 58.4 | 89.2 | 81.5 | 59.2 | 89.2 | 79.1 |
| MDD | 64 | 89.3 | 90.4 | 82.2 | 87.7 | 89.2 | 82.8 | 64.9 | 91.7 | 83.7 | 65.4 | 92 | 81.9 |
### VisDA-2017 accuracy ResNet-101
| Methods | Origin | Mean | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck | Avg |
|-------------|--------|------|-------|-------|------|------|-------|-------|-------|--------|-------|--------|-------|-------|------|
| ERM | 52.4 | 51.7 | 63.6 | 35.3 | 50.6 | 78.2 | 74.6 | 18.7 | 82.1 | 16.0 | 84.2 | 35.5 | 77.4 | 4.7 | 56.9 |
| DANN | 57.4 | 79.5 | 93.5 | 74.3 | 83.4 | 50.7 | 87.2 | 90.2 | 89.9 | 76.1 | 88.1 | 91.4 | 89.7 | 39.8 | 74.9 |
| ADDA | / | 77.5 | 95.6 | 70.8 | 84.4 | 54.0 | 87.8 | 75.8 | 88.4 | 69.3 | 84.1 | 86.2 | 85.0 | 48.0 | 74.3 |
| BSP | 75.9 | 80.5 | 95.7 | 75.6 | 82.8 | 54.5 | 89.2 | 96.5 | 91.3 | 72.2 | 88.9 | 88.7 | 88.0 | 43.4 | 76.2 |
| DAN | 61.1 | 66.4 | 89.2 | 37.2 | 77.7 | 61.8 | 81.7 | 64.3 | 90.6 | 61.4 | 79.9 | 37.7 | 88.1 | 27.4 | 67.2 |
| JAN | / | 73.4 | 96.3 | 66.0 | 82.0 | 44.1 | 86.4 | 70.3 | 87.9 | 74.6 | 83.0 | 64.6 | 84.5 | 41.3 | 70.3 |
| CDAN | / | 80.1 | 94.0 | 69.2 | 78.9 | 57.0 | 89.8 | 94.9 | 91.9 | 80.3 | 86.8 | 84.9 | 85.0 | 48.5 | 76.5 |
| MCD | 71.9 | 77.7 | 87.8 | 75.7 | 84.2 | 78.1 | 91.6 | 95.3 | 88.1 | 78.3 | 83.4 | 64.5 | 84.8 | 20.9 | 76.7 |
| AFN | 76.1 | 75.0 | 95.6 | 56.2 | 81.3 | 69.8 | 93.0 | 81.0 | 93.4 | 74.1 | 91.7 | 55.0 | 90.6 | 18.1 | 74.4 |
| MDD | / | 82.0 | 88.3 | 62.8 | 85.2 | 69.9 | 91.9 | 95.1 | 94.4 | 81.2 | 93.8 | 89.8 | 84.1 | 47.9 | 79.8 |
| MCC | 78.8 | 83.6 | 95.3 | 85.8 | 77.1 | 68.0 | 93.9 | 92.9 | 84.5 | 79.5 | 93.6 | 93.7 | 85.3 | 53.8 | 80.4 |
| FixMatch | / | 79.5 | 96.5 | 76.6 | 72.6 | 84.6 | 96.3 | 92.6 | 90.5 | 81.8 | 91.9 | 74.6 | 87.3 | 8.6 | 78.4 |
### DomainNet accuracy on ResNet-101
| Methods | c->p | c->r | c->s | p->c | p->r | p->s | r->c | r->p | r->s | s->c | s->p | s->r | Avg |
|-------------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| ERM | 32.7 | 50.6 | 39.4 | 41.1 | 56.8 | 35.0 | 48.6 | 48.8 | 36.1 | 49.0 | 34.8 | 46.1 | 43.3 |
| DAN | 38.8 | 55.2 | 43.9 | 45.9 | 59.0 | 40.8 | 50.8 | 49.8 | 38.9 | 56.1 | 45.9 | 55.5 | 48.4 |
| DANN | 37.9 | 54.3 | 44.4 | 41.7 | 55.6 | 36.8 | 50.7 | 50.8 | 40.1 | 55.0 | 45.0 | 54.5 | 47.2 |
| JAN | 40.5 | 56.7 | 45.1 | 47.2 | 59.9 | 43.0 | 54.2 | 52.6 | 41.9 | 56.6 | 46.2 | 55.5 | 50.0 |
| CDAN | 40.4 | 56.8 | 46.1 | 45.1 | 58.4 | 40.5 | 55.6 | 53.6 | 43.0 | 57.2 | 46.4 | 55.7 | 49.9 |
| MCD | 37.5 | 52.9 | 44.0 | 44.6 | 54.5 | 41.6 | 52.0 | 51.5 | 39.7 | 55.5 | 44.6 | 52.0 | 47.5 |
| MDD | 42.9 | 59.5 | 47.5 | 48.6 | 59.4 | 42.6 | 58.3 | 53.7 | 46.2 | 58.7 | 46.5 | 57.7 | 51.8 |
| MCC | 37.7 | 55.7 | 42.6 | 45.4 | 59.8 | 39.9 | 54.4 | 53.1 | 37.0 | 58.1 | 46.3 | 56.2 | 48.9 |
### DomainNet accuracy on ResNet-101 (Multi-Source)
| Methods | Origin | Avg | :c | :i | :p | :q | :r | :s |
|-------------|--------|------|------|------|------|------|------|------|
| ERM | 32.9 | 47.0 | 64.9 | 25.2 | 54.4 | 16.9 | 68.2 | 52.3 |
| MDD | / | 48.8 | 68.7 | 29.7 | 58.2 | 9.7 | 69.4 | 56.9 |
| Oracle | 63.0 | 69.1 | 78.2 | 40.7 | 71.6 | 69.7 | 83.8 | 70.6 |
### Performance on ImageNet-scale dataset
| | ResNet50, ImageNet->ImageNetR | ig_resnext101_32x8d, ImageNet->ImageSketch |
|------|-------------------------------|------------------------------------------|
| ERM | 35.6 | 54.9 |
| DAN | 39.8 | 55.7 |
| DANN | 52.7 | 56.5 |
| JAN | 41.7 | 55.7 |
| CDAN | 53.9 | 58.2 |
| MCD | 46.7 | 55.0 |
| AFN | 43.0 | 55.1 |
| MDD | 56.2 | 62.4 |
## Visualization
After training `DANN`, run the following command
```
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase analysis
```
It may take a while, then in directory ``logs/dann/Office31_A2W/visualize``, you can find
``TSNE.png``.
Following are the t-SNE of representations from ResNet50 trained on source domain and those from DANN.
## TODO
1. Support self-training methods
2. Support translation methods
3. Add results on ViT
4. Add results on ImageNet
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{DANN,
author = {Ganin, Yaroslav and Lempitsky, Victor},
Booktitle = {ICML},
Title = {Unsupervised domain adaptation by backpropagation},
Year = {2015}
}
@inproceedings{DAN,
author = {Mingsheng Long and
Yue Cao and
Jianmin Wang and
Michael I. Jordan},
title = {Learning Transferable Features with Deep Adaptation Networks},
booktitle = {ICML},
year = {2015},
}
@inproceedings{JAN,
title={Deep transfer learning with joint adaptation networks},
author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I},
booktitle={ICML},
year={2017},
}
@inproceedings{ADDA,
title={Adversarial discriminative domain adaptation},
author={Tzeng, Eric and Hoffman, Judy and Saenko, Kate and Darrell, Trevor},
booktitle={CVPR},
year={2017}
}
@inproceedings{CDAN,
author = {Mingsheng Long and
Zhangjie Cao and
Jianmin Wang and
Michael I. Jordan},
title = {Conditional Adversarial Domain Adaptation},
booktitle = {NeurIPS},
year = {2018}
}
@inproceedings{MCD,
title={Maximum classifier discrepancy for unsupervised domain adaptation},
author={Saito, Kuniaki and Watanabe, Kohei and Ushiku, Yoshitaka and Harada, Tatsuya},
booktitle={CVPR},
year={2018}
}
@InProceedings{AFN,
author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang},
title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation},
booktitle = {ICCV},
year = {2019}
}
@inproceedings{MDD,
title={Bridging theory and algorithm for domain adaptation},
author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},
booktitle={ICML},
year={2019},
}
@inproceedings{BSP,
title={Transferability vs. discriminability: Batch spectral penalization for adversarial domain adaptation},
author={Chen, Xinyang and Wang, Sinan and Long, Mingsheng and Wang, Jianmin},
booktitle={ICML},
year={2019},
}
@inproceedings{MCC,
author = {Ying Jin and
Ximei Wang and
Mingsheng Long and
Jianmin Wang},
title = {Less Confusion More Transferable: Minimum Class Confusion for Versatile
Domain Adaptation},
year={2020},
booktitle={ECCV},
}
@inproceedings{FixMatch,
title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},
author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},
booktitle={NIPS},
year={2020}
}
```
================================================
FILE: examples/domain_adaptation/image_classification/adda.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
Note: Our implementation is different from ADDA paper in several respects. We do not use separate networks for
source and target domain, nor fix classifier head. Besides, we do not adopt asymmetric objective loss function
of the feature extractor.
"""
import random
import time
import warnings
import copy
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.alignment.adda import ImageClassifier
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def set_requires_grad(net, requires_grad=False):
    """Toggle gradient computation for every parameter of ``net``.

    Args:
        net (torch.nn.Module): network whose parameters are modified in place
        requires_grad (bool): whether gradients should be computed for the
            parameters. Default: False

    Freezing parameters (the default) avoids unnecessary gradient
    computations for networks that are held fixed during a training stage.
    """
    for parameter in net.parameters():
        parameter.requires_grad_(requires_grad)
def main(args: argparse.Namespace):
    """Entry point for ADDA-style adaptation.

    Workflow: (1) optionally pretrain a classifier on the source domain,
    (2) clone it into a target classifier, (3) freeze the source classifier
    and adversarially train the target feature extractor against a domain
    discriminator, (4) evaluate / analyze depending on ``args.phase``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batch sizes equal inside the training loop
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so an epoch is defined by args.iters_per_epoch, not dataset length
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    source_classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                        pool_layer=pool_layer, finetune=not args.scratch).to(device)

    if args.phase == 'train' and args.pretrain is None:
        # first pretrain the classifier with source data
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)
        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum,
                                 weight_decay=args.weight_decay, nesterov=True)
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (
                                             -args.lr_decay))
        # start pretraining
        for epoch in range(args.pretrain_epochs):
            print("lr:", pretrain_lr_scheduler.get_lr())
            # pretrain for one epoch
            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,
                                              pretrain_lr_scheduler, epoch, args,
                                              device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)
        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    # load pretrained weights (either produced above or supplied via --pretrain)
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    source_classifier.load_state_dict(checkpoint)
    # the target classifier starts from the source weights
    target_classifier = copy.deepcopy(source_classifier)

    # freeze source classifier
    set_requires_grad(source_classifier, False)
    source_classifier.freeze_bn()

    domain_discri = DomainDiscriminator(in_feature=source_classifier.features_dim, hidden_size=1024).to(device)
    # define loss function
    grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=2., max_iters=1000, auto_step=True)
    domain_adv = DomainAdversarialLoss(domain_discri, grl=grl).to(device)

    # define optimizer and lr scheduler
    # note that we only optimize target feature extractor
    optimizer = SGD(target_classifier.get_parameters(optimize_head=False) + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        target_classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(target_classifier.backbone, target_classifier.pool_layer,
                                          target_classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, target_classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, source_classifier, target_classifier, domain_adv,
              optimizer, lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, target_classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(target_classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    target_classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, target_classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier, domain_adv: DomainAdversarialLoss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of adversarial adaptation.

    Target features (from ``target_model``) are aligned with source features
    (from the frozen ``source_model``) through ``domain_adv``; only the target
    feature extractor and the domain discriminator receive gradient updates.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_transfer, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode (source model stays frozen, see main)
    target_model.train()
    domain_adv.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        (x_s,) = next(train_source_iter)[:1]
        (x_t,) = next(train_target_iter)[:1]
        x_s, x_t = x_s.to(device), x_t.to(device)

        # measure data loading time
        data_time.update(time.time() - tic)

        # forward: keep only the feature outputs of each model
        f_s = source_model(x_s)[1]
        f_t = target_model(x_t)[1]
        transfer_loss = domain_adv(f_s, f_t)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        transfer_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update statistics
        n = x_s.size(0)
        losses_transfer.update(transfer_loss.item(), n)
        domain_accs.update(domain_adv.domain_discriminator_accuracy.item(), n)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ADDA for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    # if --pretrain is omitted, main() pretrains on the source domain first
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoint for classification model')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate of the classifier', dest='lr')
    parser.add_argument('--pretrain-lr', default=0.001, type=float, help='initial pretrain learning rate')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N',
                        help='number of total epochs (pretrain) to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='adda',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/adda.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python adda.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/adda/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python adda.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
# fix: log under logs/adda_* (was logs/dann_*, copied from dann.sh)
CUDA_VISIBLE_DEVICES=0 python adda.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/adda_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# fix: log under logs/adda_vit (was logs/dann_vit, copied from dann.sh)
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/adda_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2s
# Digits
CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/USPS2MNIST
# fix: dropped '--lr-d 0.03' — adda.py defines no such argument, so argparse would abort this run
CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/afn.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier
from tllib.modules.entropy import entropy
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for AFN (Adaptive Feature Norm) domain adaptation.

    Builds source/target loaders, trains a single shared classifier with the
    adaptive feature-norm objective, then evaluates / analyzes depending on
    ``args.phase``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batch sizes equal inside the training loop
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so an epoch is defined by args.iters_per_epoch, not dataset length
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, args.num_blocks,
                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper
    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch with the AFN objective.

    The loss is source cross-entropy plus the adaptive feature-norm penalty on
    both domains, optionally adding target entropy minimization when
    ``args.trade_off_entropy`` is set.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # classification loss (source domain only — target labels are unavailable)
        cls_loss = F.cross_entropy(y_s, labels_s)
        # norm loss: push feature norms of both domains towards a shared, growing radius
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # using entropy minimization on target predictions (optional)
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        norm_losses.update(norm_loss.item(), x_s.size(0))
        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0))
        # fix: weight the target-domain statistic by the target batch size
        # (was x_s.size(0); identical under the default equal batch sizes,
        # correct when the two loaders ever differ)
        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_t.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='AFN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='ran.crop')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('-n', '--num-blocks', default=1, type=int, help='Number of basic blocks for classifier')
    parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck')
    parser.add_argument('--dropout-p', default=0.5, type=float,
                        help='Dropout probability')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # NOTE(review): --momentum is parsed but not passed to the SGD optimizer
    # in main() (the learning rate setup there follows the original paper) —
    # this flag currently has no effect; confirm before relying on it.
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)',
                        dest='weight_decay')
    parser.add_argument('--trade-off-norm', default=0.05, type=float,
                        help='the trade-off hyper-parameter for norm loss')
    # entropy minimization is disabled when this is left at None
    parser.add_argument('--trade-off-entropy', default=None, type=float,
                        help='the trade-off hyper-parameter for entropy loss')
    parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='afn',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/afn.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for AFN (afn.py) on standard unsupervised domain
# adaptation benchmarks. Each line trains one source->target transfer task on
# a single GPU (pinned via CUDA_VISIBLE_DEVICES=0) and writes logs and
# checkpoints under the directory given by --log.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 -r 0.3 -b 36 \
--epochs 10 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python afn.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/afn/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python afn.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/afn_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# ViT backbones use --no-pool (tokens are already pooled) and a smaller batch.
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2s
# Digits
# Digits runs train from scratch on small CNNs (lenet/dtn) with grayscale/RGB
# normalization of 0.5 per channel.
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.3 --lr 0.01 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.1 --lr 0.03 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool -r 0.1 --lr 0.1 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/bsp.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.alignment.bsp import BatchSpectralPenalizationLoss, ImageClassifier
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Select the GPU when available, otherwise fall back to CPU; every model and
# batch below is moved onto this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for BSP (Batch Spectral Penalization) domain adaptation.

    Behavior is selected by ``args.phase``:
      * ``'train'``: optionally pretrain a source-only classifier (when
        ``args.pretrain`` is not given), then run adversarial training with
        the BSP regularizer, keeping the best checkpoint by validation
        accuracy and finally reporting test accuracy.
      * ``'test'``: load the best checkpoint and report test accuracy.
      * ``'analysis'``: load the best checkpoint, visualize source/target
        features with t-SNE and estimate the A-distance between domains.

    Fix vs. previous revision: the pretraining loop called the deprecated
    ``LambdaLR.get_lr()`` (which warns and is meant for internal use); it now
    uses ``get_last_lr()[0]``, consistent with the main training loop.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batches the same size for torch.cat in train().
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators: one "epoch" is args.iters_per_epoch steps, independent
    # of dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # Inverse-decay schedule: lr * (1 + gamma * step) ** (-decay)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)
    bsp_penalty = BatchSpectralPenalizationLoss().to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    if args.pretrain is None:
        # first pretrain the classifier with source data only (plain ERM)
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)
        pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum,
                                 weight_decay=args.weight_decay, nesterov=True)
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (
                                             -args.lr_decay))
        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # get_last_lr() replaces the deprecated get_lr() (see docstring).
            print("lr:", pretrain_lr_scheduler.get_last_lr()[0])
            # pretrain for one epoch
            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,
                                              pretrain_lr_scheduler, epoch, args,
                                              device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)
        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    checkpoint = torch.load(args.pretrain, map_location='cpu')
    classifier.load_state_dict(checkpoint)

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, bsp_penalty, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, bsp_penalty: BatchSpectralPenalizationLoss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of adversarial adaptation with the BSP regularizer.

    Each of the ``args.iters_per_epoch`` steps draws one source batch and one
    target batch, forwards them through the model in a single concatenated
    pass, and minimizes
    ``cls_loss + trade_off * transfer_loss + trade_off_bsp * bsp_loss``.
    Progress (timing, loss, source classification accuracy and domain
    discriminator accuracy) is printed every ``args.print_freq`` steps.
    """
    meter_batch_time = AverageMeter('Time', ':5.2f')
    meter_data_time = AverageMeter('Data', ':5.2f')
    meter_loss = AverageMeter('Loss', ':6.2f')
    meter_cls_acc = AverageMeter('Cls Acc', ':3.1f')
    meter_domain_acc = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_batch_time, meter_data_time, meter_loss, meter_cls_acc, meter_domain_acc],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        source_batch = next(train_source_iter)
        target_batch = next(train_target_iter)
        x_s = source_batch[0].to(device)
        x_t = target_batch[0].to(device)
        labels_s = source_batch[1].to(device)

        # measure data loading time
        meter_data_time.update(time.time() - tic)

        # one forward pass over the concatenated source+target batch, then
        # split outputs/features back into their source and target halves
        y, f = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        bsp_loss = bsp_penalty(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off + bsp_loss * args.trade_off_bsp

        cls_acc = accuracy(y_s, labels_s)[0]
        batch_size = x_s.size(0)
        meter_loss.update(loss.item(), batch_size)
        meter_cls_acc.update(cls_acc.item(), batch_size)
        meter_domain_acc.update(domain_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        meter_batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Command-line interface for bsp.py: every hyper-parameter is exposed as a
# flag so the companion shell script (bsp.sh) can reproduce each benchmark run.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='BSP for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    # when --pretrain is omitted, main() runs source-only pretraining itself
    # and stores the checkpoint under the log directory
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoint for classification model')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    parser.add_argument('--trade-off-bsp', default=2e-4, type=float,
                        help='the trade-off hyper-parameter for bsp loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--pretrain-lr', default=0.001, type=float, help='initial pretrain learning rate')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N',
                        help='number of total epochs(pretrain) to run (default: 3)')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='bsp',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/bsp.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for BSP (bsp.py) on standard unsupervised domain
# adaptation benchmarks. Each line trains one source->target transfer task on
# a single GPU (pinned via CUDA_VISIBLE_DEVICES=0) and writes logs and
# checkpoints under the directory given by --log.
#
# Fix vs. previous revision: the ViT and ig_resnext101_32x8d runs logged into
# logs/dann_vit/... and logs/dann_ig_resnext101_32x8d/... (copy-paste leftover
# from dann.sh), which would mix BSP results with DANN results. They now use
# bsp_vit / bsp_ig_resnext101_32x8d, consistent with afn.sh.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
--epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/bsp/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python bsp.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python bsp.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/bsp_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2s
# Digits
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/SVHN2MNIST --pretrain-epochs 1
================================================
FILE: examples/domain_adaptation/image_classification/cc_loss.py
================================================
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier
from tllib.self_training.cc_loss import CCConsistency
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for MCC + class-confusion-consistency domain adaptation.

    Dispatches on ``args.phase``:
      * ``'train'``    -- train for ``args.epochs`` epochs, keep the checkpoint
                          with the best validation acc@1, then report test acc@1.
      * ``'test'``     -- load the best checkpoint and evaluate on the test split.
      * ``'analysis'`` -- load the best checkpoint, plot a t-SNE of source/target
                          features and print the A-distance between domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # Deterministic CUDNN trades speed for reproducibility; warn about it.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    # Source images get a single random-resized-crop view; target images get a
    # (weak, strong) pair via MultipleApply -- the strong view adds AutoAugment
    # and feeds the consistency loss in train().
    # NOTE(review): the source transform hard-codes scale/ratio instead of using
    # args.scale/args.ratio like the target views -- confirm this is intended.
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    # drop_last=True keeps every training batch full-sized, which the losses assume.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # Infinite iterators so train() can draw a fixed number of iterations per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    # Inverse-decay schedule: lr(x) = lr * (1 + gamma * x) ** (-decay), per step.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return
    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return
    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train `model` for one epoch of ``args.iters_per_epoch`` iterations.

    Per iteration the total loss is
        CE(source) + trade_off * MCC(target weak view)
        + trade_off_consistency * CCConsistency(target weak view, target strong view).

    Source and both target views are forwarded in a single concatenated batch.
    The lr scheduler is stepped once per iteration (see the LambdaLR in main()).
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    # define loss function
    mcc = MinimumClassConfusionLoss(temperature=args.temperature)
    consistency = CCConsistency(temperature=args.temperature, thr=args.thr)
    # switch to train mode
    model.train()
    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        # Target labels are never used (unsupervised adaptation); discard them.
        (x_t, x_t_strong), _ = next(train_target_iter)[:2]
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # compute output: one forward pass over source + weak target + strong target
        x = torch.cat((x_s, x_t, x_t_strong), dim=0)
        y, f = model(x)
        y_s, y_t, y_t_strong = y.chunk(3, dim=0)
        cls_loss = F.cross_entropy(y_s, labels_s)
        mcc_loss = mcc(y_t)
        # CCConsistency also returns the ratio of samples passing the confidence
        # threshold; it is not tracked here.
        consistency_loss, _ = consistency(y_t, y_t_strong)
        transfer_loss = mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency
        loss = cls_loss + transfer_loss
        cls_acc = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CC Loss for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--temperature', default=2.5, type=float, help='parameter temperature scaling')
    parser.add_argument('--thr', default=0.95, type=float, help='thr parameter for consistency loss')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for original mcc loss')
    # Hyphenated spelling matches every other flag in this script; the legacy
    # underscore spelling is kept as an alias for backward compatibility.
    parser.add_argument('--trade-off-consistency', '--trade_off_consistency', default=1., type=float,
                        dest='trade_off_consistency',
                        help='the trade-off hyper-parameter for consistency loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    # NOTE(review): default log dir 'mcc' looks inherited from mcc.py -- confirm
    # whether it should read 'cc_loss'; left unchanged to preserve behavior.
    parser.add_argument("--log", type=str, default='mcc',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/cc_loss.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for cc_loss.py (MCC + class-confusion consistency) on
# the standard UDA benchmarks. Each command trains one source->target pair.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=5 python cc_loss.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/cc_loss/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/cc_loss/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/cc_loss_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# ViT runs use --no-pool (ViT features are already pooled) and a smaller batch.
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2s
# Digits
# Digit backbones (lenet/dtn) are trained from scratch on 0.5-normalized inputs.
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/cdan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for CDAN (Conditional Domain Adversarial Network) training.

    Dispatches on ``args.phase``:
      * ``'train'``    -- adversarially train classifier + domain discriminator,
                          keep the checkpoint with the best validation acc@1,
                          then report test acc@1.
      * ``'test'``     -- load the best checkpoint and evaluate on the test split.
      * ``'analysis'`` -- load the best checkpoint, plot a t-SNE of source/target
                          features and print the A-distance between domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # Deterministic CUDNN trades speed for reproducibility; warn about it.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps every training batch full-sized for the adversarial loss.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # Infinite iterators so train() can draw a fixed number of iterations per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    classifier_feature_dim = classifier.features_dim
    # The discriminator input is the (feature x prediction) outer product; with
    # --randomized it is projected down to randomized_dim instead.
    if args.randomized:
        domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)
    else:
        domain_discri = DomainDiscriminator(classifier_feature_dim * num_classes, hidden_size=1024).to(device)
    all_parameters = classifier.get_parameters() + domain_discri.get_parameters()
    # define optimizer and lr scheduler
    optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # Inverse-decay schedule: lr(x) = lr * (1 + gamma * x) ** (-decay), per step.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    # define loss function
    domain_adv = ConditionalDomainAdversarialLoss(
        domain_discri, entropy_conditioning=args.entropy,
        num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized,
        randomized_dim=args.randomized_dim
    ).to(device)
    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return
    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return
    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of CDAN training.

    Each iteration draws one labeled source batch and one unlabeled target
    batch, forwards both through the classifier in a single concatenated
    pass, and minimizes ``cls_loss + args.trade_off * transfer_loss`` where
    the transfer loss is the conditional domain-adversarial loss.

    Args:
        train_source_iter: endless iterator over labeled source-domain data
        train_target_iter: endless iterator over unlabeled target-domain data
        model: classifier returning ``(logits, features)`` for a batch
        domain_adv: conditional domain adversarial loss module
        optimizer: SGD optimizer stepped once per iteration
        lr_scheduler: learning-rate scheduler stepped once per iteration
        epoch: current epoch index (used for progress display only)
        args: command-line options (``iters_per_epoch``, ``trade_off``,
            ``print_freq``)
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: forward source and target in one concatenated batch,
        # then split logits/features back into the two domains
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        # Store plain Python floats (.item()) instead of tensors so the meters
        # do not keep device tensors alive; this also matches the sibling
        # dan.py implementation, which already calls .item() here.
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CDAN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('-r', '--randomized', action='store_true',
                        help='using randomized multi-linear-map (default: False)')
    parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
                        help='randomized dimension when using randomized multi-linear-map (default: 1024)')
    parser.add_argument('--entropy', default=False, action='store_true', help='use entropy conditioning')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='cdan',
                        help="Where to save logs, checkpoints and debugging images.")
    # NOTE: the two help strings are implicitly concatenated; the trailing
    # space on the first keeps the sentences separated in --help output.
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/cdan.sh
================================================
#!/usr/bin/env bash
# Training recipes for CDAN (cdan.py) on the standard unsupervised domain
# adaptation benchmarks.  Each command is an independent run: -s/-t pick the
# source/target domain(s), -a the backbone, and --log the output directory.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/cdan/VisDA2017
# ResNet101, DomainNet, Single Source
# Use randomized multi-linear-map to decrease GPU memory usage
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python cdan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python cdan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/cdan_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# (--no-pool: ViT backbones emit a feature vector already, so the pooling layer is skipped)
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2s
# Digits (small grayscale/RGB images: custom resizing, normalization and LeNet/DTN backbones trained from scratch)
CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/dan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier
from tllib.modules.kernels import GaussianKernel
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for DAN: train, test, or analyze depending on ``args.phase``.

    - ``train`` (default): train with the MK-MMD transfer loss, keep the best
      checkpoint by validation accuracy, then report test accuracy.
    - ``test``: load the best checkpoint and evaluate on the test set only.
    - ``analysis``: load the best checkpoint, plot a t-SNE of source/target
      features and report the A-distance between the two domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding turns on deterministic cuDNN, trading speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target training batches the same size
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # endless iterators: training is driven by iterations per epoch, not dataset size
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler (lr decays polynomially with the iteration count)
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function: MK-MMD over five Gaussian kernels with bandwidths alpha = 2^k, k in [-3, 1]
    mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(
        kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
        linear=not args.non_linear
    )

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mkmmd_loss, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint ('latest' every epoch, copied to 'best' on improvement)
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set with the best validation checkpoint
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch with the DAN objective.

    Every iteration draws one labeled source batch and one unlabeled target
    batch, forwards each through the classifier, and minimizes the source
    cross-entropy plus ``args.trade_off`` times the multi-kernel MMD between
    the two batches of bottleneck features.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, trans_losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # put both the classifier and the loss module in training mode
    model.train()
    mkmmd_loss.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        (x_t,) = next(train_target_iter)[:1]
        x_s, x_t, labels_s = x_s.to(device), x_t.to(device), labels_s.to(device)

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tick)

        # forward the source batch, then the target batch
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mkmmd_loss(f_s, f_t)
        loss = cls_loss + transfer_loss * args.trade_off
        cls_acc = accuracy(y_s, labels_s)[0]

        n = x_s.size(0)
        losses.update(loss.item(), n)
        cls_accs.update(cls_acc.item(), n)
        trans_losses.update(transfer_loss.item(), n)

        # backward pass, then parameter and learning-rate updates
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time for the whole iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DAN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--non-linear', default=False, action='store_true',
                        help='whether not use the linear version')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # NOTE: no explicit dest here, so argparse derives ``args.wd`` from the
    # first long option (--wd); main() reads args.wd accordingly.
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dan',
                        help="Where to save logs, checkpoints and debugging images.")
    # NOTE: the two help strings are implicitly concatenated; the trailing
    # space on the first keeps the sentences separated in --help output.
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/dan.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2A
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2D
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
# --per-class-eval reports mean per-class accuracy (the standard VisDA metric)
CUDA_VISIBLE_DEVICES=0 python dan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dan/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python dan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python dan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/dan_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# --no-pool skips the CNN-style pooling layer after the backbone; -b 24 lowers the
# batch size to fit the ViT in memory
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2s
# Digits
# Small CNNs (LeNet/DTN) are trained from scratch (--scratch) with dataset-specific
# resizing, normalization, and learning rates
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point: train, test, or analyze a DANN model depending on ``args.phase``.

    - ``train``: adversarial training, periodic validation, checkpointing of the
      best model, and a final evaluation on the test split.
    - ``test``: load the best checkpoint and evaluate on the test split only.
    - ``analysis``: load the best checkpoint, extract features from both domains,
      plot a t-SNE embedding, and report the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seed Python and torch RNGs and force deterministic cuDNN kernels for
        # reproducibility (at the cost of speed; see the warning below).
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # class_names is stored on args so downstream utilities (e.g. per-class eval)
    # can access it.
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last keeps source/target batch sizes equal, which the training loop
    # relies on when chunking the concatenated forward pass in half.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Infinite iterators so training is driven by iterations per epoch rather
    # than dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    # ViT-style backbones already return pooled features; --no-pool replaces the
    # default pooling with identity.
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # Polynomial decay schedule. NOTE(review): the lambda's factor includes
    # args.lr, so the per-group base lrs from get_parameters() are presumably
    # relative multipliers rather than absolute rates — confirm against
    # Classifier.get_parameters.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function (wraps the discriminator with a gradient-reversal layer)
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate the best (not the last) model on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of DANN training.

    Each iteration draws one source and one target mini-batch, forwards them
    jointly through the classifier, and minimizes the source classification
    loss plus the (trade-off weighted) domain-adversarial transfer loss.
    """
    # running statistics displayed every args.print_freq iterations
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs, domain_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch both the classifier and the adversarial head to train mode
    model.train()
    domain_adv.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]
        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s = labels_s.to(device)

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tic)

        # single joint forward pass over source + target, then split the
        # logits/features back into their per-domain halves
        y, f = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        loss = cls_loss + transfer_loss * args.trade_off

        batch = x_s.size(0)
        losses.update(loss.item(), batch)
        cls_accs.update(accuracy(y_s, labels_s)[0].item(), batch)
        # accuracy of the discriminator, recorded by the adversarial loss
        domain_accs.update(domain_adv.domain_discriminator_accuracy.item(), batch)

        # backprop and parameter/schedule updates
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time for the full iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for DANN training/evaluation/analysis.
    parser = argparse.ArgumentParser(description='DANN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    # FIX: the two implicitly concatenated help strings previously rendered as
    # "...only test the model.When phase is 'analysis', only analysis the model."
    # — add the separating space and fix the grammar ("analyze").
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/dann.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for DANN on the standard domain adaptation benchmarks.
# Each invocation writes logs and checkpoints under the directory given by --log.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
# --per-class-eval reports mean per-class accuracy (the standard VisDA metric)
CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dann/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python dann.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/dann_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# --no-pool skips the CNN-style pooling layer after the backbone; -b 24 lowers the
# batch size to fit the ViT in memory
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2s
# Digits
# Small CNNs (LeNet/DTN) are trained from scratch (--scratch) with dataset-specific
# resizing, normalization, and learning rates
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.modules.classifier import Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
from tllib.utils.data import ForeverDataIterator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Run the source-only (ERM) pipeline for unsupervised domain adaptation.

    The behavior is selected by ``args.phase``:
      * ``'train'``    -- train on the source domain only, keep the checkpoint
                          with the best validation acc@1, then report test acc@1
      * ``'test'``     -- load the best checkpoint and evaluate on the test set
      * ``'analysis'`` -- load the best checkpoint, visualize source/target
                          features with t-SNE and estimate the A-distance

    Args:
        args (argparse.Namespace): command-line arguments parsed in ``__main__``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # the target training loader is only consumed by the 'analysis' phase below;
    # ERM training itself never sees target data
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # get_lr() is deprecated since PyTorch 1.4 (it warns when called from user
        # code and can misreport values for chained schedulers); get_last_lr()
        # returns the lr actually applied, matching fixmatch.py in this directory
        print(lr_scheduler.get_last_lr())
        # train for one epoch
        utils.empirical_risk_minimization(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args, device)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Source Only for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    # multiple source/target domains may be given, e.g. "-s Ar Cl -t Pr"
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    # defaults are the standard ImageNet normalization statistics
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # lr-gamma / lr-decay parameterize the inverse-decay schedule built in main():
    # lr * (1 + gamma * iter) ** (-decay)
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/erm.sh
================================================
#!/usr/bin/env bash
# ERM (source-only) baselines: train on the source domain(s) only, then evaluate
# on the target domain. Argument meanings are defined in erm.py.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/erm/VisDA2017
# ResNet101, DomainNet, Oracle
# "Oracle" runs train and test on the same domain (-s equals -t): a supervised
# upper bound for the target domain.
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s q -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_s
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/erm/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python erm.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/erm_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# ViT backbones use --no-pool (class token instead of global pooling) and a
# smaller batch size to fit in memory.
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2s
# Digits
# Small backbones (LeNet / DTN) are trained from scratch (--scratch), with 0.5
# mean/std normalization and plain resizing ('res.').
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/fixmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.classifier import Classifier
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ImageClassifier(Classifier):
    """A :class:`Classifier` whose bottleneck is a Linear -> BatchNorm1d -> ReLU
    projection of the backbone features, and whose ``forward`` returns class
    logits only (the parent's feature output is not exposed).
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        # project backbone features down/up to `bottleneck_dim` before the head
        projection = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, projection, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """Map an input batch to class logits."""
        pooled = self.pool_layer(self.backbone(x))
        return self.head(self.bottleneck(pooled))
def main(args: argparse.Namespace):
    """Run the FixMatch pipeline for unsupervised domain adaptation.

    The behavior is selected by ``args.phase``:
      * 'train'    -- train with labeled source data plus FixMatch-style
                      self-training on unlabeled target data; keep the best
                      checkpoint by validation acc@1, then report test acc@1
      * 'test'     -- load the best checkpoint and evaluate on the test set
      * 'analysis' -- load the best checkpoint, visualize features with t-SNE
                      and estimate the A-distance
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # source images use a fixed RandomResizedCrop range (0.08, 1.0) regardless
    # of --scale/--ratio, which only shape the two target-domain augmentations
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    # weak augmentation: crop/flip only; strong augmentation additionally
    # applies the AutoAugment policy chosen by --auto-augment
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
    # each target sample yields a (weak, strong) pair of views
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # unlabeled target data may use a larger batch size (-ub) than labeled source data
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.unlabeled_batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train `model` for one epoch of FixMatch.

    Per iteration: a supervised cross-entropy loss on labeled source data, plus a
    confidence-thresholded self-training loss that pushes predictions on
    strongly-augmented target images toward pseudo labels derived from the
    weakly-augmented views. Target ground-truth labels are used only to report
    pseudo-label accuracy, never for training.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    cls_losses = AverageMeter('Cls Loss', ':6.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':6.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    # only target predictions above args.threshold contribute to the loss
    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        # each target sample is a (weak, strong) augmentation pair (MultipleApply)
        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # predictions on the weakly-augmented view are detached from the graph:
        # they only serve as the pseudo-label source
        with torch.no_grad():
            y_t = model(x_t)

        # cross entropy loss
        # note: backward() is called per-loss instead of once on the sum, so the
        # source graph is freed before the strong-augmentation forward pass;
        # gradients accumulate in .grad and optimizer.step() applies both
        y_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        cls_loss.backward()

        # self-training loss
        y_t_strong = model(x_t_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_t_strong, y_t)
        self_training_loss = args.trade_off * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        # (losses already back-propagated above; this sum is for logging only)
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), x_s.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        self_training_losses.update(self_training_loss.item(), x_s.size(0))
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # ratio of pseudo labels
        # mask marks target samples whose confidence cleared the threshold
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / x_t.size(0)
        pseudo_label_ratios.update(ratio.item() * 100, x_t.size(0))

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
            # rejected positions become -1 so they cannot match any real label
            # (assumes class indices are non-negative)
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_t).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FixMatch for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    # multiple source/target domains may be given, e.g. "-s Ar Cl -t Pr"
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    # --scale/--ratio shape only the target-domain augmentations (see main())
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    # defaults are the standard ImageNet normalization statistics
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # policy for the strong augmentation branch of FixMatch
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('-ub', '--unlabeled-batch-size', default=32, type=int,
                        help='mini-batch size of unlabeled data (target domain) (default: 32)')
    # pseudo labels below this confidence are masked out of the self-training loss
    parser.add_argument('--threshold', default=0.9, type=float,
                        help='confidence threshold')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # lr-gamma / lr-decay parameterize the inverse-decay schedule built in main():
    # lr * (1 + gamma * iter) ** (-decay)
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/fixmatch.sh
================================================
#!/usr/bin/env bash
# FixMatch training recipes: each command runs one source->target adaptation task.
# Logs and checkpoints are written under the directory given by --log.

# ResNet50, Office31, Single Source
# -ub 96 uses a larger unlabeled (target) batch than the default labeled batch of 32.
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2A
# ResNet50, Office-Home, Single Source
# Office-Home uses a higher lr (0.003) and a wider bottleneck (1024) than Office31.
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
# VisDA: center-crop resizing, lower confidence threshold (0.8), per-class accuracy reporting.
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --train-resizing cen.crop \
    --lr 0.003 --threshold 0.8 --bottleneck-dim 2048 --epochs 20 -ub 64 --seed 0 --per-class-eval --log logs/fixmatch/VisDA2017
================================================
FILE: examples/domain_adaptation/image_classification/jan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier, Theta
from tllib.modules.kernels import GaussianKernel
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train/test/analyze a JAN model for unsupervised domain adaptation.

    Depending on ``args.phase``:
      * 'train'    -- train on labeled source + unlabeled target data, then test
                      the best checkpoint.
      * 'test'     -- only evaluate the best checkpoint on the test set.
      * 'analysis' -- produce a t-SNE plot and the A-distance between domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # Seeding turns on deterministic CUDNN, which trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batches the same size for the joint forward pass
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # infinite iterators: training runs a fixed number of iterations per epoch
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    # define loss function
    if args.adversarial:
        # adversarial variant: learnable Theta transforms over features and logits
        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]
    else:
        thetas = None
    # JMMD over (bottleneck features, softmax predictions), one Gaussian kernel family per layer
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear, thetas=thetas
    ).to(device)
    parameters = classifier.get_parameters()
    if thetas is not None:
        # thetas get their own (fixed) learning rate of 0.1
        parameters += [{"params": theta.parameters(), 'lr': 0.1} for theta in thetas]
    # define optimizer
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this factor, and the
    # factor itself includes args.lr again; this matches the other example scripts in this
    # repo, but confirm the lr**2 scaling is intentional.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return
    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return
    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one training epoch of JAN.

    Each iteration draws one labeled source batch and one unlabeled target batch,
    forwards them jointly, and minimizes source cross-entropy plus the JMMD
    alignment loss (weighted by ``args.trade_off``) over (features, predictions).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # put the model and the (stateful) JMMD loss into training mode
    model.train()
    jmmd_loss.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        source_batch = next(train_source_iter)
        x_s, labels_s = source_batch[0].to(device), source_batch[1].to(device)
        x_t = next(train_target_iter)[0].to(device)

        # measure data loading time
        data_time.update(time.time() - tic)

        # one joint forward pass over the concatenated source+target batch
        outputs, features = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = outputs.chunk(2, dim=0)
        f_s, f_t = features.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        # align the joint distribution of (features, class probabilities)
        transfer_loss = jmmd_loss(
            (f_s, F.softmax(y_s, dim=1)),
            (f_t, F.softmax(y_t, dim=1))
        )
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        batch_size = x_s.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)
        trans_losses.update(transfer_loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for the JAN example; parsed args are handed to main().
    parser = argparse.ArgumentParser(description='JAN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--linear', default=False, action='store_true',
                        help='whether use the linear version')
    parser.add_argument('--adversarial', default=False, action='store_true',
                        help='whether use adversarial theta')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # NOTE(review): no `dest` is given, so argparse derives dest='wd' from the first
    # long option string; main() reads this value as `args.wd`.
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='jan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/jan.sh
================================================
#!/usr/bin/env bash
# JAN (Joint Adaptation Networks) training recipes: one source->target task per command.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2A
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2D
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
    --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/jan/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python jan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python jan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/jan_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
# ViT backbones skip the pooling layer (--no-pool) and use a smaller batch (-b 24) to fit memory.
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2s
# Digits
CUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/USPS2MNIST
# Fix: use GPU 0 like every other recipe in this file (was CUDA_VISIBLE_DEVICES=4,
# a leftover from a multi-GPU machine that fails on single-GPU setups).
CUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/mcc.py
================================================
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train/test/analyze an MCC model for unsupervised domain adaptation.

    Depending on ``args.phase``:
      * 'train'    -- train on labeled source + unlabeled target data, then test
                      the best checkpoint.
      * 'test'     -- only evaluate the best checkpoint on the test set.
      * 'analysis' -- produce a t-SNE plot and the A-distance between domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # Seeding turns on deterministic CUDNN, which trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batches the same size for the joint forward pass
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # infinite iterators: training runs a fixed number of iterations per epoch
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this factor, and the
    # factor itself includes args.lr again; this matches the other example scripts in this
    # repo, but confirm the lr**2 scaling is intentional.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    # define loss function
    mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)
    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return
    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return
    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mcc_loss, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, mcc: MinimumClassConfusionLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of MCC training.

    Each iteration draws one source batch (with labels) and one target batch
    (unlabeled), forwards them through ``model`` in a single concatenated
    batch, and minimizes source cross-entropy plus ``args.trade_off`` times
    the minimum-class-confusion loss on the target logits.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        (x_t,) = next(train_target_iter)[:1]
        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - tic)

        # forward source and target together in one batch, then split the logits
        y, _ = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = y.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mcc(y_t)
        loss = cls_loss + transfer_loss * args.trade_off

        n = x_s.size(0)
        losses.update(loss.item(), n)
        cls_accs.update(accuracy(y_s, labels_s)[0].item(), n)
        trans_losses.update(transfer_loss.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MCC for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    # multiple domains may be given for multi-source adaptation
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # MCC-specific hyper-parameters: softmax temperature of the confusion loss
    # and the weight of the transfer loss relative to source cross-entropy
    parser.add_argument('--temperature', default=2.5, type=float, help='parameter temperature scaling')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # lr schedule used by main(): lr * (1 + lr_gamma * iter) ** (-lr_decay)
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='mcc',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/mcc.sh
================================================
#!/usr/bin/env bash
#
# Reproduction commands for MCC (Minimum Class Confusion) on the image
# classification benchmarks.  Each line is an independent experiment; logs
# and checkpoints are written under the directory given by --log.
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
# NOTE(review): this command pins GPU 5 while every other command uses GPU 0
# — confirm the device index is intended before running on a shared machine.
CUDA_VISIBLE_DEVICES=5 python mcc.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
--epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/mcc/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python mcc.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/mcc/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python mcc.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/mcc_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2s
# Digits (LeNet / DTN trained from scratch, 0.5 mean/std normalization)
CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
--resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/mcd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
from typing import Tuple
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
import torch.utils.data
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.mcd import ImageClassifierHead, entropy, classifier_discrepancy
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for MCD training/evaluation.

    Depending on ``args.phase``:
      * ``'train'``    — train the feature extractor G and the two classifier
        heads F1/F2, checkpointing the best model by validation accuracy,
        then report accuracy of the best checkpoint on the test set;
      * ``'test'``     — load the best checkpoint and evaluate on the test set;
      * ``'analysis'`` — load the best checkpoint, extract features from both
        domains, plot t-SNE, and report the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # seeding makes cuDNN deterministic, at a (possibly large) speed cost
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last keeps source/target batches the same size, which train() relies
    # on when chunking the concatenated forward pass in two
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so train() can draw a fixed number of batches per epoch
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    G = utils.get_model(args.arch, pretrain=not args.scratch).to(device)  # feature extractor
    # two image classifier heads
    pool_layer = nn.Identity() if args.no_pool else None
    F1 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)
    F2 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)

    # define optimizer
    # the learning rate is fixed according to the original paper
    # NOTE(review): optimizer_g has no momentum while optimizer_f uses 0.9 —
    # confirm this asymmetry between generator and classifier updates is intended
    optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_f = SGD([
        {"params": F1.parameters()},
        {"params": F2.parameters()},
    ], momentum=0.9, lr=args.lr, weight_decay=0.0005)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        G.load_state_dict(checkpoint['G'])
        F1.load_state_dict(checkpoint['F1'])
        F2.load_state_dict(checkpoint['F2'])

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(G, F1.pool_layer).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, G, F1, F2, args)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    best_results = None
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g, optimizer_f, epoch, args)

        # evaluate on validation set
        results = validate(val_loader, G, F1, F2, args)

        # remember best acc@1 (over both heads) and save checkpoint
        torch.save({
            'G': G.state_dict(),
            'F1': F1.state_dict(),
            'F2': F2.state_dict()
        }, logger.get_checkpoint_path('latest'))
        if max(results) > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(results)
        best_results = results

    print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

    # evaluate on test set
    checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
    G.load_state_dict(checkpoint['G'])
    F1.load_state_dict(checkpoint['F1'])
    F2.load_state_dict(checkpoint['F2'])
    results = validate(test_loader, G, F1, F2, args)
    print("test_acc1 = {:3.1f}".format(max(results)))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead,
          optimizer_g: SGD, optimizer_f: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch of MCD's three-step adversarial training.

    Per iteration:
      Step A — update G, F1 and F2 to minimize source classification loss
               (plus a target entropy term);
      Step B — update only F1/F2 to minimize source loss while MAXIMIZING the
               discrepancy between the two heads on target data (note the
               minus sign on the discrepancy term);
      Step C — update only G, ``args.num_k`` times, to minimize that same
               discrepancy.

    The exact ordering of forward passes, ``zero_grad``/``step`` calls and the
    per-optimizer updates is what implements the min-max game; do not reorder.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    G.train()
    F1.train()
    F2.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        # source and target are forwarded as one concatenated batch
        x = torch.cat((x_s, x_t), dim=0)
        assert x.requires_grad is False

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)

        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy
        loss.backward()
        optimizer_g.step()
        optimizer_f.step()

        # Step B train classifier to maximize discrepancy
        # (fresh forward pass after the Step A update; only optimizer_f steps)
        optimizer_g.zero_grad()
        optimizer_f.zero_grad()

        g = G(x)
        y_1 = F1(g)
        y_2 = F2(g)
        y1_s, y1_t = y_1.chunk(2, dim=0)
        y2_s, y2_t = y_2.chunk(2, dim=0)
        y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
        loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \
               (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy - \
               classifier_discrepancy(y1_t, y2_t) * args.trade_off
        loss.backward()
        optimizer_f.step()

        # Step C train generator to minimize discrepancy, num_k times
        for k in range(args.num_k):
            optimizer_g.zero_grad()
            g = G(x)
            y_1 = F1(g)
            y_2 = F2(g)
            y1_s, y1_t = y_1.chunk(2, dim=0)
            y2_s, y2_t = y_2.chunk(2, dim=0)
            y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1)
            mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off
            mcd_loss.backward()
            optimizer_g.step()

        # metrics are taken from the last Step C forward pass
        cls_acc = accuracy(y1_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(mcd_loss.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead,
             F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]:
    """Evaluate both classifier heads on ``val_loader``.

    Returns a pair ``(acc_of_F1, acc_of_F2)`` — average top-1 accuracy of each
    head.  When ``args.per_class_eval`` is set, a per-class confusion-matrix
    summary (based on F1's predictions) is also printed.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    top1_1 = AverageMeter('Acc_1', ':6.2f')
    top1_2 = AverageMeter('Acc_2', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, top1_1, top1_2],
        prefix='Test: ')

    # switch to evaluate mode
    G.eval()
    F1.eval()
    F2.eval()

    confmat = ConfusionMatrix(len(args.class_names)) if args.per_class_eval else None

    with torch.no_grad():
        start = time.time()
        for batch_idx, batch in enumerate(val_loader):
            images, target = batch[:2]
            images, target = images.to(device), target.to(device)

            # one shared feature pass, two head predictions
            features = G(images)
            logits1, logits2 = F1(features), F2(features)

            # measure accuracy of each head
            acc1, = accuracy(logits1, target)
            acc2, = accuracy(logits2, target)
            if confmat:
                confmat.update(target, logits1.argmax(1))
            n = images.size(0)
            top1_1.update(acc1.item(), n)
            top1_2.update(acc2.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - start)
            start = time.time()

            if batch_idx % args.print_freq == 0:
                progress.display(batch_idx)

    print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}'
          .format(top1_1=top1_1, top1_2=top1_2))
    if confmat:
        print(confmat.format(args.class_names))

    return top1_1.avg, top1_2.avg
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MCD for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    # multiple domains may be given for multi-source adaptation
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int)
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # MCD-specific hyper-parameters: weights of the discrepancy / entropy
    # terms, and how many generator updates per iteration (Step C)
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    parser.add_argument('--trade-off-entropy', default=0.01, type=float,
                        help='the trade-off hyper-parameter for entropy loss')
    parser.add_argument('--num-k', type=int, default=4, metavar='K',
                        help='how many steps to repeat the generator update')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='mcd',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/mcd.sh
================================================
#!/usr/bin/env bash
# Benchmark reproduction commands for MCD (Maximum Classifier Discrepancy).
# Hyper-parameters below are the tuned values per benchmark; do not change them casually.
# ResNet50, Office31, Single Source
# We found MCD loss is sensitive to class number,
# thus, when the class number increase, please increase trade-off correspondingly.
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2A
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2D
# ResNet50, Office-Home, Single Source (65 classes -> larger trade-off, per the note above)
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
# NOTE(review): `--center-crop` looks stale — resizing here is already controlled by
# `--train-resizing cen.crop`; confirm mcd.py's argparse still accepts `--center-crop`.
CUDA_VISIBLE_DEVICES=0 python mcd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 --center-crop --seed 0 -i 500 --per-class-eval --train-resizing cen.crop --log logs/mcd/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t i -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2i
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 --log logs/mcd/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python mcd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 100.0 --log logs/mcd/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch (no --seed here, run is non-deterministic)
CUDA_VISIBLE_DEVICES=0 python mcd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --trade-off 500.0 --log logs/mcd_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source (-b 24 and --no-pool for the ViT backbone)
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_Rw2Pr
# Digits (trained from scratch on small CNNs; note the per-task learning rates)
CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/mdd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import os.path as osp
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \
as MarginDisparityDiscrepancy, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Single global device: use the first visible GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run the MDD pipeline selected by ``args.phase``.

    - ``'train'``: fit the model, checkpointing the weights with the best
      validation accuracy, then report accuracy of that checkpoint on the
      test split.
    - ``'test'``: load the best checkpoint and evaluate on the test split.
    - ``'analysis'``: load the best checkpoint, plot a t-SNE of source/target
      features and report the A-distance between the two domains.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding also turns on deterministic CUDNN, trading speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source and target batches the same size so the
    # training loop can concatenate them into a single forward pass.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators: an "epoch" below is a fixed number of iterations,
    # not a single pass over either dataset.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 width=args.bottleneck_dim, pool_layer=pool_layer).to(device)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr_scheduler
    # The learning rate of the classifiers are set 10 times to that of the feature extractor by default.
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # Polynomial decay: lr(x) = lr * (1 + gamma * x)^(-decay), stepped every iteration.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mdd, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of MDD training.

    Each iteration draws one source and one target batch, minimizes the
    source cross-entropy plus ``args.trade_off`` times the (negated) margin
    disparity discrepancy, and steps the optimizer and lr scheduler.
    """
    # Running statistics shown by the progress printer.
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    meters = [batch_time, data_time, losses, trans_losses, cls_accs]
    progress = ProgressMeter(args.iters_per_epoch, meters, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    classifier.train()
    mdd.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)[:2]
        (x_t,) = next(train_target_iter)[:1]
        x_s, x_t, labels_s = x_s.to(device), x_t.to(device), labels_s.to(device)

        # time spent waiting for data
        data_time.update(time.time() - tick)

        # single forward pass over the concatenated source+target batch
        joint = torch.cat((x_s, x_t), dim=0)
        logits, logits_adv = classifier(joint)
        y_s, y_t = logits.chunk(2, dim=0)
        y_s_adv, y_t_adv = logits_adv.chunk(2, dim=0)

        # supervised loss on the labelled source batch
        cls_loss = F.cross_entropy(y_s, labels_s)
        # the adversarial head maximizes MDD, hence minimizing the negated value
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        # presumably advances the classifier's internal schedule (e.g. GRL
        # coefficient); kept at the original once-per-iteration call site.
        classifier.step()

        n = x_s.size(0)
        cls_acc = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), n)
        cls_accs.update(cls_acc.item(), n)
        trans_losses.update(transfer_loss.item(), n)

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # CLI entry point: build the argument parser and dispatch to main().
    parser = argparse.ArgumentParser(description='MDD for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int)
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--margin', type=float, default=4., help="margin gamma")
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.004, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0002, type=float)
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # Fixed: help text previously claimed "(default: 4)" while the actual default is 2.
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='mdd',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/mdd.sh
================================================
#!/usr/bin/env bash
# Benchmark reproduction commands for MDD (Margin Disparity Discrepancy).
# --bottleneck-dim is tuned per dataset (1024 for Office31/VisDA, 2048 elsewhere).
# ResNet50, Office31, Single Source (note: seed 1 here, unlike the other sections)
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python mdd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --epochs 30 \
  --bottleneck-dim 1024 --seed 0 --train-resizing cen.crop --per-class-eval -b 36 --log logs/mdd/VisDA2017
# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2r
# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python mdd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 40 -i 2500 -p 500 \
  --bottleneck-dim 2048 --seed 0 --lr 0.004 --train-resizing cen.crop --log logs/mdd/ImageNet_IN2INR
# ig_resnext101_32x8d, ImageNet -> ImageNetSketch (smaller margin gamma of 2.)
CUDA_VISIBLE_DEVICES=0 python mdd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d \
  --epochs 40 -i 2500 -p 500 --bottleneck-dim 2048 --margin 2. --seed 0 --lr 0.004 --train-resizing cen.crop \
  --log logs/mdd_ig_resnext101_32x8d/ImageNet_IN2sketch
# Vision Transformer, Office-Home, Single Source (-b 24 and --no-pool for the ViT backbone)
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2s
# Digits (trained from scratch on small CNNs)
CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/SVHN2MNIST
================================================
FILE: examples/domain_adaptation/image_classification/requirements.txt
================================================
timm
================================================
FILE: examples/domain_adaptation/image_classification/self_ensemble.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import utils
from tllib.self_training.pi_model import L2ConsistencyLoss
from tllib.self_training.mean_teacher import EMATeacher
from tllib.self_training.self_ensemble import ClassBalanceLoss, ImageClassifier
from tllib.vision.transforms import ResizeImage, MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Single global device: use the first visible GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point of the self-ensemble example.

    Depending on ``args.phase`` this either trains (optionally pretraining on
    the source domain first), tests the best checkpoint, or runs feature
    analysis (t-SNE + A-distance).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # deterministic CUDNN for reproducibility; slower training
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # we find self ensemble is sensitive to data augmentation. The following
    # data augmentation performs well for evaluated datasets
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        ResizeImage(256),
        T.RandomCrop(224),
        T.RandomHorizontalFlip(),
        T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),
        T.RandomGrayscale(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        ResizeImage(256),
        T.CenterCrop(224),
        T.ToTensor(),
        normalize
    ])

    # target-domain training samples get two views (train + val transform),
    # one for the student and one for the EMA teacher
    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target,
                          train_transform, val_transform, MultipleApply([train_transform, val_transform]))
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = Adam(classifier.get_parameters(), args.lr)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint when not training
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    if args.pretrain is None:
        # first pretrain the classifier with source data only
        print("Pretraining the model on source domain.")
        args.pretrain = logger.get_checkpoint_path('pretrain')
        pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                         pool_layer=pool_layer, finetune=not args.scratch).to(device)
        pretrain_optimizer = Adam(pretrain_model.get_parameters(), args.pretrain_lr)
        pretrain_lr_scheduler = LambdaLR(pretrain_optimizer,
                                         lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** (
                                             -args.lr_decay))

        # start pretraining
        for epoch in range(args.pretrain_epochs):
            # pretrain for one epoch
            utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer,
                                              pretrain_lr_scheduler, epoch, args,
                                              device)
            # validate to show pretrain process
            utils.validate(val_loader, pretrain_model, args, device)

        torch.save(pretrain_model.state_dict(), args.pretrain)
        print("Pretraining process is done.")

    # the student classifier always starts from the (pre)trained weights
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    classifier.load_state_dict(checkpoint)

    # teacher tracks the student via an exponential moving average
    teacher = EMATeacher(classifier, alpha=args.alpha)
    consistency_loss = L2ConsistencyLoss().to(device)
    class_balance_loss = ClassBalanceLoss(num_classes).to(device)

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, teacher, consistency_loss, class_balance_loss,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          teacher: EMATeacher, consistency_loss, class_balance_loss,
          optimizer: Adam, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of self-ensemble training.

    Each iteration combines a supervised cross-entropy loss on source data
    with a confidence-masked consistency loss (student vs. EMA teacher on two
    views of the same target image) and a class-balance loss; the teacher is
    updated after every optimizer step.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    cons_losses = AverageMeter('Cons Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, cons_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        # the target iterator yields two augmented views of the same image
        (x_t1, x_t2), = next(train_target_iter)[:1]
        x_s = x_s.to(device)
        x_t1 = x_t1.to(device)
        x_t2 = x_t2.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: student sees view 1, teacher sees view 2
        y_s, _ = model(x_s)
        y_t, _ = model(x_t1)
        y_t_teacher, _ = teacher(x_t2)

        # classification loss on labeled source samples
        cls_loss = F.cross_entropy(y_s, labels_s)

        # compute output and confidence mask from the teacher's predictions
        p_t = F.softmax(y_t, dim=1)
        p_t_teacher = F.softmax(y_t_teacher, dim=1)
        confidence, _ = p_t_teacher.max(dim=1)
        mask = (confidence > args.threshold).float()

        # consistency loss, only applied where the teacher is confident
        cons_loss = consistency_loss(p_t, p_t_teacher, mask)
        # balance loss, scaled by the fraction of confident samples
        balance_loss = class_balance_loss(p_t) * mask.mean()
        loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # update teacher (EMA) after the student step
        teacher.update()

        # update statistics
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_losses.update(cls_loss.item(), x_s.size(0))
        cons_losses.update(cons_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # command-line interface; defaults reproduce the paper's Office31 setup
    parser = argparse.ArgumentParser(description='Self Ensemble for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoint for classification model')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--pretrain-lr', '--pretrain-learning-rate', default=3e-5, type=float,
                        help='initial pretrain learning rate', dest='pretrain_lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--alpha', default=0.99, type=float, help='ema decay rate (default: 0.99)')
    parser.add_argument('--threshold', default=0.8, type=float, help='confidence threshold')
    parser.add_argument('--trade-off-cons', default=3, type=float, help='trade off parameter for consistency loss')
    parser.add_argument('--trade-off-balance', default=0.01, type=float,
                        help='trade off parameter for class balance loss')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--pretrain-epochs', default=0, type=int, metavar='N',
                        help='number of total epochs(pretrain) to run')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='self_ensemble',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_classification/self_ensemble.sh
================================================
#!/usr/bin/env bash
# Benchmark scripts for self_ensemble.py. Each command runs one
# source->target adaptation task on a single GPU (device 0).

# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t A -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t A -a resnet50 --seed 1 --log logs/self_ensemble/Office31_W2A

# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Pr

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 --seed 0 --per-class-eval --log logs/self_ensemble/VisDA2017 --lr-gamma 0.0002 -b 32

# Office-Home on Vision Transformer
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Pr
================================================
FILE: examples/domain_adaptation/image_classification/utils.py
================================================
"""
@author: Junguang Jiang, Baixu Chen
@contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com
"""
import sys
import os.path as osp
import time
from PIL import Image
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from timm.data.auto_augment import auto_augment_transform, rand_augment_transform
sys.path.append('../../..')
import tllib.vision.datasets as datasets
import tllib.vision.models as models
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.vision.datasets.imagelist import MultipleDomainsDataset
def get_model_names():
    """Return the names of all supported backbones: tllib models plus timm models."""
    tllib_names = [
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ]
    return sorted(tllib_names) + timm.list_models()
def get_model(model_name, pretrain=True):
    """Create a headless backbone for feature extraction.

    Args:
        model_name (str): a name from :func:`get_model_names` — either a
            tllib model or a pytorch-image-models (timm) model.
        pretrain (bool): load pretrained weights when True.

    Returns:
        The backbone module with its classification head removed and an
        ``out_features`` attribute describing its feature dimension.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=pretrain)
    else:
        # load models from pytorch-image-models (timm)
        backbone = timm.create_model(model_name, pretrained=pretrain)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except AttributeError:
            # narrowed from a bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit): some timm models expose the
            # classifier as `head` rather than via get_classifier()
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone
def get_dataset_names():
    """Return the names of all supported datasets, including the 'Digits' alias."""
    supported = [
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    ]
    return sorted(supported) + ['Digits']
def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):
    """Build source/target/val/test datasets for a domain-adaptation task.

    Returns a 6-tuple ``(train_source_dataset, train_target_dataset,
    val_dataset, test_dataset, num_classes, class_names)``. For 'Digits',
    `source`/`target` hold dataset class names (e.g. MNIST); otherwise they
    hold task/domain names of a single multi-domain dataset.
    """
    if train_target_transform is None:
        # by default, target training data is transformed like source data
        train_target_transform = train_source_transform
    if dataset_name == "Digits":
        # digits benchmarks: source/target are themselves dataset names
        train_source_dataset = datasets.__dict__[source[0]](osp.join(root, source[0]), download=True,
                                                            transform=train_source_transform)
        train_target_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), download=True,
                                                            transform=train_target_transform)
        val_dataset = test_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), split='test',
                                                                  download=True, transform=val_transform)
        class_names = datasets.MNIST.get_classes()
        num_classes = len(class_names)
    elif dataset_name in datasets.__dict__:
        # load datasets from tllib.vision.datasets
        dataset = datasets.__dict__[dataset_name]

        def concat_dataset(tasks, start_idx, **kwargs):
            # combine several domains into one dataset, tagging samples
            # with consecutive domain ids starting at start_idx
            # return ConcatDataset([dataset(task=task, **kwargs) for task in tasks])
            return MultipleDomainsDataset([dataset(task=task, **kwargs) for task in tasks], tasks,
                                          domain_ids=list(range(start_idx, start_idx + len(tasks))))

        train_source_dataset = concat_dataset(root=root, tasks=source, download=True, transform=train_source_transform,
                                              start_idx=0)
        train_target_dataset = concat_dataset(root=root, tasks=target, download=True, transform=train_target_transform,
                                              start_idx=len(source))
        val_dataset = concat_dataset(root=root, tasks=target, download=True, transform=val_transform,
                                     start_idx=len(source))
        if dataset_name == 'DomainNet':
            # DomainNet ships an explicit test split; others evaluate on val
            test_dataset = concat_dataset(root=root, tasks=target, split='test', download=True, transform=val_transform,
                                          start_idx=len(source))
        else:
            test_dataset = val_dataset
        class_names = train_source_dataset.datasets[0].classes
        num_classes = len(class_names)
    else:
        raise NotImplementedError(dataset_name)
    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names
def validate(val_loader, model, args, device) -> float:
    """Evaluate ``model`` on ``val_loader`` and return top-1 accuracy (in %).

    When ``args.per_class_eval`` is set, a per-class confusion-matrix summary
    is also printed.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    if args.per_class_eval:
        confmat = ConfusionMatrix(len(args.class_names))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(val_loader):
            # datasets may yield extra fields (e.g. domain ids); only the
            # first two entries (image, label) are used here
            images, target = data[:2]
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1,))
            if confmat:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        if confmat:
            print(confmat.format(args.class_names))

    return top1.avg
def get_train_transform(resizing='default', scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), random_horizontal_flip=True,
                        random_color_jitter=False, resize_size=224, norm_mean=(0.485, 0.456, 0.406),
                        norm_std=(0.229, 0.224, 0.225), auto_augment=None):
    """
    Build the training-time augmentation pipeline.

    resizing mode:
        - default: resize the image to 256 and take a random resized crop of size 224;
        - cen.crop: resize the image to 256 and take the center crop of size 224;
        - ran.crop: resize the image to 256 and take a random crop of size 224;
        - res.: resize the image to `resize_size`.
    """
    # spatial transform; track the resulting image size for auto-augment
    transformed_img_size = 224
    if resizing == 'default':
        spatial = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224, scale=scale, ratio=ratio)
        ])
    elif resizing == 'cen.crop':
        spatial = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224)
        ])
    elif resizing == 'ran.crop':
        spatial = T.Compose([
            ResizeImage(256),
            T.RandomCrop(224)
        ])
    elif resizing == 'res.':
        spatial = ResizeImage(resize_size)
        transformed_img_size = resize_size
    else:
        raise NotImplementedError(resizing)

    pipeline = [spatial]
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())
    if auto_augment:
        # timm auto-augment expects 0-255 channel means and the image size
        aa_params = dict(
            translate_const=int(transformed_img_size * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]),
            interpolation=Image.BILINEAR
        )
        if auto_augment.startswith('rand'):
            pipeline.append(rand_augment_transform(auto_augment, aa_params))
        else:
            pipeline.append(auto_augment_transform(auto_augment, aa_params))
    elif random_color_jitter:
        # color jitter is only used when auto-augment is disabled
        pipeline.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    pipeline.append(T.ToTensor())
    pipeline.append(T.Normalize(mean=norm_mean, std=norm_std))
    return T.Compose(pipeline)
def get_val_transform(resizing='default', resize_size=224,
                      norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """
    Build the evaluation-time transform.

    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to `resize_size`.
    """
    if resizing == 'default':
        spatial = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
        ])
    elif resizing == 'res.':
        spatial = ResizeImage(resize_size)
    else:
        raise NotImplementedError(resizing)
    return T.Compose([
        spatial,
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ])
def empirical_risk_minimization(train_source_iter, model, optimizer, lr_scheduler, epoch, args, device):
    """Train ``model`` for one epoch with plain cross-entropy on source data.

    Used both as the ERM baseline and as the source-only pretraining stage of
    adaptation methods.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output and classification loss
        y_s, f_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        loss = cls_loss

        cls_acc = accuracy(y_s, labels_s)[0]
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
================================================
FILE: examples/domain_adaptation/image_regression/README.md
================================================
# Unsupervised Domain Adaptation for Image Regression Tasks
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to better reproduce the benchmark results.
## Dataset
Following datasets can be downloaded automatically:
- [DSprites](https://github.com/deepmind/dsprites-dataset)
- [MPI3D](https://github.com/rr-learning/disentanglement_dataset)
## Supported Methods
Supported methods include:
- [Disparity Discrepancy (DD)](https://arxiv.org/abs/1904.05801)
- [Representation Subspace Distance (RSD)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Representation-Subspace-Distance-for-Domain-Adaptation-Regression-icml21.pdf)
## Experiment and Results
The shell files give the script to reproduce the benchmark results with specified hyper-parameters.
For example, if you want to train DD on DSprites, use the following script
```shell script
# Train a DD on DSprites C->N task using ResNet 18.
# Assume you have put the datasets under the path `data/dSprites`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/mdd/dSprites_C2N --wd 0.0005
```
**Notations**
- ``Origin`` means the result reported by the original paper.
- ``Avg`` is the mean absolute error reported by Transfer-Learn (these are regression tasks, so lower is better).
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.
Labels are all normalized to [0, 1] to eliminate the effects of diverse scale in regression values.
We repeat each DD experiment three times and report the average error of the ``final`` epoch.
### dSprites error on ResNet-18
| Methods | Avg | C → N | C → S | N → C | N → S | S → C | S → N |
|-------------|-------|-------|-------|-------|-------|-------|-------|
| ERM | 0.157 | 0.232 | 0.271 | 0.081 | 0.22 | 0.038 | 0.092 |
| DD | 0.057 | 0.047 | 0.08 | 0.03 | 0.095 | 0.053 | 0.037 |
### MPI3D error on ResNet-18
| Methods | Avg | RL → RC | RL → T | RC → RL | RC → T | T → RL | T → RC |
|-------------|-------|---------|--------|---------|--------|--------|--------|
| ERM | 0.176 | 0.232 | 0.271 | 0.081 | 0.22 | 0.038 | 0.092 |
| DD | 0.03 | 0.086 | 0.029 | 0.057 | 0.189 | 0.131 | 0.087 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{MDD,
title={Bridging theory and algorithm for domain adaptation},
author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},
booktitle={ICML},
year={2019},
}
@inproceedings{RSD,
title={Representation Subspace Distance for Domain Adaptation Regression},
author={Chen, Xinyang and Wang, Sinan and Wang, Jianmin and Long, Mingsheng},
booktitle={ICML},
year={2021}
}
```
================================================
FILE: examples/domain_adaptation/image_regression/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import utils
from tllib.modules.regressor import Regressor
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.modules.domain_discriminator import DomainDiscriminator
import tllib.vision.datasets.regression as datasets
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# run on the first GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point of the DANN-for-regression example.

    Depending on ``args.phase`` this trains a regressor with a domain
    adversarial loss, tests the best checkpoint, or runs feature analysis
    (t-SNE + A-distance).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # deterministic CUDNN for reproducibility; slower training
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    if args.normalization == 'IN':
        # swap BatchNorm for InstanceNorm when requested
        backbone = utils.convert_model(backbone)
    num_factors = train_source_dataset.num_factors
    bottleneck = nn.Sequential(
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        nn.Flatten(),
        nn.Linear(backbone.out_features, 256),
        nn.ReLU()
    )
    regressor = Regressor(backbone=backbone, num_factors=num_factors, bottleneck=bottleneck, bottleneck_dim=256).to(
        device)
    print(regressor)
    domain_discri = DomainDiscriminator(in_feature=regressor.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(regressor.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    dann = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint when not training
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)
        print(mae)
        return

    # start training; track best (lowest) mean absolute error
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, dann, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` steps) of DANN training.

    The optimized objective is source-domain MSE plus the domain-adversarial
    loss weighted by ``args.trade_off``. Target labels are consumed only to
    report target MAE; they never contribute to the gradient.
    """
    # Progress bookkeeping. The display names are part of the log format.
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_mse = AverageMeter('MSE Loss', ':6.3f')
    meter_dann = AverageMeter('DANN Loss', ':6.3f')
    meter_dacc = AverageMeter('Domain Acc', ':3.1f')
    meter_mae_s = AverageMeter('MAE Loss (s)', ':6.3f')
    meter_mae_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_mse, meter_dann, meter_mae_s, meter_mae_t, meter_dacc],
        prefix="Epoch: [{}]".format(epoch))

    # Both the regressor and the adversarial loss module hold trainable state.
    model.train()
    domain_adv.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        # One mini-batch per domain, moved to the compute device.
        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        meter_data.update(time.time() - tic)

        # Forward both domains through the shared regressor.
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)  # monitoring only, no gradient use
        transfer_loss = domain_adv(f_s, f_t)
        loss = mse_loss + transfer_loss * args.trade_off
        domain_acc = domain_adv.domain_discriminator_accuracy

        batch = x_s.size(0)
        meter_mse.update(mse_loss.item(), batch)
        meter_dann.update(transfer_loss.item(), batch)
        meter_mae_s.update(mae_loss_s.item(), batch)
        meter_mae_t.update(mae_loss_t.item(), batch)
        meter_dacc.update(domain_acc.item(), batch)

        # Backprop, SGD step, then advance the learning-rate schedule.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        meter_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Valid backbones: every lowercase public callable in tllib.vision.models.
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # Valid datasets: every public callable in tllib.vision.datasets.regression.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='DANN for Regression Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: the help text previously claimed the default was Office31,
    # but the actual default is DSprites.
    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',
                        help='dataset: ' + ' | '.join(dataset_names) +
                             ' (default: DSprites)')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-size', type=int, default=128)
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet18)')
    parser.add_argument('--normalization', default='BN', type=str, choices=["BN", "IN"])
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # fix: help said "(default: 5e-4)" but the actual default is 0.001.
    parser.add_argument('--wd', '--weight-decay', default=0.001, type=float,
                        metavar='W', help='weight decay (default: 1e-3)')
    # fix: help said "(default: 4)" but the actual default is 2.
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_regression/dann.sh
================================================
# Reproduction script: DANN for regression domain adaptation.
# Each command trains on source task (-s), adapts to target task (-t),
# and writes logs/checkpoints under the --log directory.

# DSprites
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2N
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2S
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2C
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2S
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2C
CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2N

# MPI3D (images are resized to 224 here, unlike the DSprites runs)
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2RC --resize-size 224
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2T --resize-size 224
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2RL --resize-size 224
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2T --resize-size 224
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RL --resize-size 224
CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RC --resize-size 224
================================================
FILE: examples/domain_adaptation/image_regression/dd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import torch.nn as nn
import utils
from tllib.alignment.mdd import RegressionMarginDisparityDiscrepancy as MarginDisparityDiscrepancy, ImageRegressor
import tllib.vision.datasets.regression as datasets
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for DD (margin disparity discrepancy) regression DA.

    ``args.phase`` selects the mode:
      * ``'train'``    -- full training loop with per-epoch validation
      * ``'test'``     -- evaluate the best checkpoint on the target test split
      * ``'analysis'`` -- t-SNE plot and A-distance of the learned features
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic cuDNN trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code. ImageNet statistics to match the pre-trained backbone.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # Validation uses the *target* test split: we measure adaptation quality.
    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators so train() can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_factors = train_source_dataset.num_factors
    backbone = models.__dict__[args.arch](pretrained=True)
    bottleneck_dim = args.bottleneck_dim
    if args.normalization == 'IN':
        # Instance-norm variant: swap the backbone's normalization layers and
        # build convolutional bottleneck / head / adversarial head with
        # InstanceNorm2d, ending in Sigmoid since factor labels lie in [0, 1]
        # (presumably -- TODO confirm against the regression dataset classes).
        backbone = utils.convert_model(backbone)
        bottleneck = nn.Sequential(
            nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
        )
        head = nn.Sequential(
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(bottleneck_dim, num_factors),
            nn.Sigmoid()
        )
        # Small-variance Gaussian init for conv/linear layers of the main head.
        for layer in head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        # Adversarial head mirrors the main head's architecture.
        adv_head = nn.Sequential(
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(bottleneck_dim, num_factors),
            nn.Sigmoid()
        )
        for layer in adv_head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        regressor = ImageRegressor(backbone, num_factors, bottleneck=bottleneck, head=head, adv_head=adv_head,
                                   bottleneck_dim=bottleneck_dim, width=bottleneck_dim)
    else:
        # Default BN variant: let ImageRegressor build its own bottleneck/heads.
        regressor = ImageRegressor(backbone, num_factors,
                                   bottleneck_dim=bottleneck_dim, width=bottleneck_dim)
    regressor = regressor.to(device)
    print(regressor)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr scheduler (inverse-decay schedule per iteration)
    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        # extract features from both domains (head[:-2] drops the final
        # Linear + Sigmoid so a vector feature comes out)
        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck, regressor.head[:-2]).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, mdd, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train for one epoch with regression margin disparity discrepancy (MDD).

    Source MSE is minimized while the discrepancy between the main and
    adversarial heads is *maximized* via the gradient reversal mechanism,
    hence the subtraction in the loss below. Target labels are used only
    to report target-domain MAE.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    source_losses = AverageMeter('Source Loss', ':6.3f')
    trans_losses = AverageMeter('Trans Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, source_losses, trans_losses, mae_losses_s, mae_losses_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: both domains are concatenated into one forward pass,
        # which keeps normalization statistics shared across domains
        x = torch.cat([x_s, x_t], dim=0)
        outputs, outputs_adv = model(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute mean square loss on source domain
        mse_loss = F.mse_loss(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv)
        # for adversarial classifier, minimize negative mdd is equal to maximize mdd
        loss = mse_loss - transfer_loss * args.trade_off
        # NOTE(review): presumably advances the gradient-reversal-layer
        # coefficient schedule inside ImageRegressor -- confirm against
        # tllib.alignment.mdd before relying on this ordering.
        model.step()

        # target MAE uses target labels for monitoring only
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        source_losses.update(mse_loss.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))
        mae_losses_s.update(mae_loss_s.item(), x_s.size(0))
        mae_losses_t.update(mae_loss_t.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Valid backbones: every lowercase public callable in tllib.vision.models.
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # Valid datasets: every public callable in tllib.vision.datasets.regression.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='DD for Regression Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: the help text previously claimed the default was Office31,
    # but the actual default is DSprites.
    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',
                        help='dataset: ' + ' | '.join(dataset_names) +
                             ' (default: DSprites)')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-size', type=int, default=128)
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=512, type=int)
    parser.add_argument('--normalization', default='BN', type=str, choices=['BN', 'IN'])
    parser.add_argument('--margin', type=float, default=1., help="margin gamma")
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='dd',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_regression/dd.sh
================================================
# Reproduction script: DD (margin disparity discrepancy) for regression DA.
# DSprites uses batch 128 with BN; MPI3D uses instance normalization,
# 224x224 inputs, and a larger weight decay.

# DSprites
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2N --wd 0.0005
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2S --wd 0.0005
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2C --wd 0.0005
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2S --wd 0.0005
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2C --wd 0.0005
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2N --wd 0.0005

# MPI3D
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2RC --normalization IN --resize-size 224 --weight-decay 0.001
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2T --normalization IN --resize-size 224 --weight-decay 0.001
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2RL --normalization IN --resize-size 224 --weight-decay 0.001
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2T --normalization IN --resize-size 224 --weight-decay 0.001
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RL --normalization IN --resize-size 224 --weight-decay 0.001
CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RC --normalization IN --resize-size 224 --weight-decay 0.001
================================================
FILE: examples/domain_adaptation/image_regression/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import utils
from tllib.modules.regressor import Regressor
import tllib.vision.datasets.regression as datasets
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for the source-only (ERM) regression baseline.

    ``args.phase`` selects the mode:
      * ``'train'``    -- fit on the source domain, validate on the target test split
      * ``'test'``     -- evaluate the best checkpoint
      * ``'analysis'`` -- t-SNE plot and A-distance of the learned features
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic cuDNN trades speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code. ImageNet statistics to match the pre-trained backbone.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # Validation uses the *target* test split: we measure transfer quality.
    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators so train() can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    if args.normalization == 'IN':
        # Swap the backbone's normalization layers for instance norm.
        backbone = utils.convert_model(backbone)
    num_factors = train_source_dataset.num_factors
    regressor = Regressor(backbone=backbone, num_factors=num_factors).to(device)
    print(regressor)

    # define optimizer and lr scheduler (inverse-decay schedule per iteration)
    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)
        print(mae)
        return

    # start training
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)

        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """One epoch of plain ERM training: minimize MSE on the source domain.

    The target iterator is consumed only to report target-domain MAE;
    target labels never influence the optimized loss.
    """
    # Progress bookkeeping. The display names are part of the log format.
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_mse = AverageMeter('MSE Loss', ':6.3f')
    meter_mae_s = AverageMeter('MAE Loss (s)', ':6.3f')
    meter_mae_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_mse, meter_mae_s, meter_mae_t],
        prefix="Epoch: [{}]".format(epoch))

    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        # One mini-batch per domain, moved to the compute device.
        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device).float()
        x_t, labels_t = next(train_target_iter)
        x_t = x_t.to(device)
        labels_t = labels_t.to(device).float()

        meter_data.update(time.time() - tic)

        # Forward passes; the feature output is unused here.
        y_s, _ = model(x_s)
        y_t, _ = model(x_t)

        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)  # monitoring only
        loss = mse_loss

        batch = x_s.size(0)
        meter_mse.update(mse_loss.item(), batch)
        meter_mae_s.update(mae_loss_s.item(), batch)
        meter_mae_t.update(mae_loss_t.item(), batch)

        # Backprop, SGD step, then advance the learning-rate schedule.
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        meter_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Valid backbones: every lowercase public callable in tllib.vision.models.
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # Valid datasets: every public callable in tllib.vision.datasets.regression.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='Source Only for Regression Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: the help text previously claimed the default was Office31,
    # but the actual default is DSprites.
    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',
                        help='dataset: ' + ' | '.join(dataset_names) +
                             ' (default: DSprites)')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-size', type=int, default=128)
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet18)')
    parser.add_argument('--normalization', default='BN', type=str, choices=["IN", "BN"])
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_regression/erm.sh
================================================
# Reproduction script: source-only (ERM) baseline for regression DA.
# DSprites runs use batch 128; MPI3D runs use batch 36 and more epochs.

# DSprites
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2N
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2S
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2C
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2S
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2C
CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2N

# MPI3D
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2RC
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2T
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2RL
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2T
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RL
CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RC
================================================
FILE: examples/domain_adaptation/image_regression/rsd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
import utils
from tllib.modules.regressor import Regressor
from tllib.alignment.rsd import RepresentationSubspaceDistance
import tllib.vision.datasets.regression as datasets
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for RSD regression domain adaptation.

    Depending on ``args.phase`` this either trains the regressor ('train'),
    evaluates the best checkpoint on the target test split ('test'), or
    extracts features for t-SNE visualization and A-distance ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN is required for reproducibility but slows training.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: train and val share the same deterministic transform
    # (resize + normalize); no random augmentation is used for this task.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.resize_size),
        T.ToTensor(),
        normalize
    ])

    dataset = datasets.__dict__[args.data]
    train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True,
                                   transform=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True,
                                   transform=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # validation is always on the target domain's test split
    val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so each epoch runs a fixed number of iterations
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](pretrained=True)
    if args.normalization == 'IN':
        # optionally replace BatchNorm layers with InstanceNorm
        backbone = utils.convert_model(backbone)
    num_factors = train_source_dataset.num_factors
    bottleneck = nn.Sequential(
        nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        nn.Flatten(),
        nn.Linear(backbone.out_features, 256),
        nn.ReLU()
    )
    regressor = Regressor(backbone=backbone, num_factors=num_factors, bottleneck=bottleneck,
                          bottleneck_dim=256).to(device)
    print(regressor)

    # define optimizer and lr scheduler
    # (the scheduler is stepped once per iteration inside train(), so the
    # polynomial decay below is a function of the iteration index)
    optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function (Representation Subspace Distance for alignment)
    rsd = RepresentationSubspaceDistance(args.trade_off_bmp)

    # resume from the best checkpoint for test/analysis phases
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        regressor.load_state_dict(checkpoint)

    # analyze the model: feature visualization and domain discrepancy
    if args.phase == 'analysis':
        train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=args.workers, drop_last=True)
        # extract features from both domains
        feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)
        print(mae)
        return

    # start training; model selection is by lowest target-domain MAE
    best_mae = 100000.
    for epoch in range(args.epochs):
        # train for one epoch
        print("lr", lr_scheduler.get_lr())
        train(train_source_iter, train_target_iter, regressor, rsd, optimizer,
              lr_scheduler, epoch, args)
        # evaluate on validation set
        mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device)
        # remember best mae and save checkpoint
        torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest'))
        if mae < best_mae:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_mae = min(mae, best_mae)
        print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae))

    print("best_mae = {:6.3f}".format(best_mae))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Regressor, rsd, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of supervised MSE regression plus RSD domain alignment.

    The optimized objective is ``mse(source) + trade_off * rsd(f_s, f_t)``;
    source/target MAE are tracked for monitoring only.  The lr scheduler is
    stepped once per iteration.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    mse_losses = AverageMeter('MSE Loss', ':6.3f')
    rsd_losses = AverageMeter('RSD Loss', ':6.3f')
    mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f')
    mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, mse_losses, rsd_losses, mae_losses_s, mae_losses_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        # fetch one batch per domain and move it onto the compute device
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device).float()
        x_t, labels_t = x_t.to(device), labels_t.to(device).float()

        # measure data loading time
        data_time.update(time.time() - tick)

        # forward both domains through the shared regressor
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)
        mse_loss = F.mse_loss(y_s, labels_s)
        mae_loss_s = F.l1_loss(y_s, labels_s)
        mae_loss_t = F.l1_loss(y_t, labels_t)
        rsd_loss = rsd(f_s, f_t)
        loss = mse_loss + rsd_loss * args.trade_off

        # record scalar statistics
        batch_size = x_s.size(0)
        for meter, value in ((mse_losses, mse_loss), (rsd_losses, rsd_loss),
                             (mae_losses_s, mae_loss_s), (mae_losses_t, mae_loss_t)):
            meter.update(value.item(), batch_size)

        # backpropagate and update parameters / learning rate
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # discover all callable, lower-case model constructors and dataset classes
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='RSD for Regression Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fixed help text: the actual default is DSprites, not Office31
    parser.add_argument('-d', '--data', metavar='DATA', default='DSprites',
                        help='dataset: ' + ' | '.join(dataset_names) +
                             ' (default: DSprites)')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-size', type=int, default=128)
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet18)')
    parser.add_argument('--normalization', default='BN', type=str, choices=["BN", "IN"])
    parser.add_argument('--trade-off', default=0.001, type=float)
    parser.add_argument('--trade-off-bmp', default=0.1, type=float)
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # fixed help text: the actual default is 0.001 (1e-3), not 5e-4
    parser.add_argument('--wd', '--weight-decay', default=0.001, type=float,
                        metavar='W', help='weight decay (default: 1e-3)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='rsd',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/image_regression/rsd.sh
================================================
# RSD training runs for regression domain adaptation.
# Batch size is left at the script default; MPI3D additionally uses 224px inputs.

# DSprites
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2N
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2S
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2C
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2S
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2C
CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2N

# MPI3D
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2RC --resize-size 224
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2T --resize-size 224
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2RL --resize-size 224
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2T --resize-size 224
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RL --resize-size 224
CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RC --resize-size 224
================================================
FILE: examples/domain_adaptation/image_regression/utils.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import sys
import time
import torch
import torch.nn.functional as F
from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d
from torch.nn.modules.instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d
sys.path.append('../../..')
from tllib.utils.meter import AverageMeter, ProgressMeter
def convert_model(module):
    """Recursively replace every BatchNorm in ``module`` with an InstanceNorm.

    BatchNorm1d/2d/3d layers are swapped for InstanceNorm1d/2d/3d with the
    same ``num_features``, ``eps``, ``momentum`` and ``affine`` settings.
    When the batch norm is affine, its learned per-channel weight and bias
    are copied into the new layer, so a pretrained backbone keeps its scale
    and shift (the original conversion silently re-initialized them).

    Args:
        module (nn.Module): module (or whole model) to convert in place.

    Returns:
        nn.Module: the converted module (the root may be a new object).
    """
    source_modules = (BatchNorm1d, BatchNorm2d, BatchNorm3d)
    target_modules = (InstanceNorm1d, InstanceNorm2d, InstanceNorm3d)
    for src_module, tgt_module in zip(source_modules, target_modules):
        if isinstance(module, src_module):
            mod = tgt_module(module.num_features, module.eps, module.momentum, module.affine)
            if module.affine:
                # preserve the pretrained affine parameters
                with torch.no_grad():
                    mod.weight.copy_(module.weight)
                    mod.bias.copy_(module.bias)
            module = mod
    for name, child in module.named_children():
        module.add_module(name, convert_model(child))
    return module
def validate(val_loader, model, args, factors, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean MAE over all factors.

    One MAE meter is kept per regression factor; per-factor averages are
    printed at the end and their mean is returned.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    mae_losses = [AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors]
    progress = ProgressMeter(
        len(val_loader),
        [batch_time] + mae_losses,
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for batch_idx, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)

            # predict all factors at once, then score each column separately
            output = model(images)
            batch_size = images.size(0)
            for j, meter in enumerate(mae_losses):
                meter.update(F.l1_loss(output[:, j], target[:, j]).item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - tick)
            tick = time.time()

            if batch_idx % args.print_freq == 0:
                progress.display(batch_idx)

    for meter, factor in zip(mae_losses, factors):
        print("{} MAE {mae.avg:6.3f}".format(factor, mae=meter))

    return sum(meter.avg for meter in mae_losses) / len(factors)
================================================
FILE: examples/domain_adaptation/keypoint_detection/README.md
================================================
# Unsupervised Domain Adaptation for Keypoint Detection
It’s suggested to use **pytorch==1.7.1** and **torchvision==0.8.2** in order to better reproduce the benchmark results.
## Dataset
Following datasets can be downloaded automatically:
- [Rendered Handpose Dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html)
- [Hand-3d-Studio Dataset](https://www.yangangwang.com/papers/ZHAO-H3S-2020-02.html)
- [FreiHAND Dataset](https://lmb.informatik.uni-freiburg.de/projects/freihand/)
- [Surreal Dataset](https://www.di.ens.fr/willow/research/surreal/data/)
- [LSP Dataset](http://sam.johnson.io/research/lsp.html)
You need to prepare following datasets manually if you want to use them:
- [Human3.6M Dataset](http://vision.imar.ro/human3.6m/description.php)
and prepare them following [Documentations for Human3.6M Dataset](/common/vision/datasets/keypoint_detection/human36m.py).
## Supported Methods
Supported methods include:
- [Regressive Domain Adaptation for Unsupervised Keypoint Detection (RegDA, CVPR 2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf)
## Experiment and Results
The shell files give the script to reproduce the results with specified hyper-parameters.
For example, if you want to train RegDA on RHD->H3D, use the following script
```shell script
# Train a RegDA on RHD -> H3D task using PoseResNet.
# Assume you have put the datasets under the path `data/RHD` and `data/H3D_crop`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \
-s RenderedHandPose -t Hand3DStudio --finetune --seed 0 --debug --log logs/regda/rhd2h3d
```
### RHD->H3D accuracy on ResNet-101
| Methods | MCP | PIP | DIP | Fingertip | Avg |
|-------------|------|------|------|-----------|------|
| ERM | 67.4 | 64.2 | 63.3 | 54.8 | 61.8 |
| RegDA | 79.6 | 74.4 | 71.2 | 62.9 | 72.5 |
| Oracle | 97.7 | 97.2 | 95.7 | 92.5 | 95.8 |
### Surreal->Human3.6M accuracy on ResNet-101
| Methods | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Avg |
|-------------|----------|-------|-------|------|------|-------|------|
| ERM | 69.4 | 75.4 | 66.4 | 37.9 | 77.3 | 77.7 | 67.3 |
| RegDA | 73.3 | 86.4 | 72.8 | 54.8 | 82.0 | 84.4 | 75.6 |
| Oracle | 95.3 | 91.8 | 86.9 | 95.6 | 94.1 | 93.6 | 92.9 |
### Surreal->LSP accuracy on ResNet-101
| Methods | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Avg |
|-------------|----------|-------|-------|------|------|-------|------|
| ERM | 51.5 | 65.0 | 62.9 | 68.0 | 68.7 | 67.4 | 63.9 |
| RegDA | 62.7 | 76.7 | 71.1 | 81.0 | 80.3 | 75.3 | 74.6 |
| Oracle | 95.3 | 91.8 | 86.9 | 95.6 | 94.1 | 93.6 | 92.9 |
## Visualization
If you want to visualize the keypoint detection results during training, pass the `--debug` flag.
```
CUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0
```
Then you can find visualization images in directory ``logs/erm/rhd2h3d/visualize/``.
## TODO
Add support for the following methods: CycleGAN
## Citation
If you use these methods in your research, please consider citing.
```
@InProceedings{RegDA,
author = {Junguang Jiang and
Yifei Ji and
Ximei Wang and
Yufeng Liu and
Jianmin Wang and
Mingsheng Long},
title = {Regressive Domain Adaptation for Unsupervised Keypoint Detection},
booktitle = {CVPR},
year = {2021}
}
```
================================================
FILE: examples/domain_adaptation/keypoint_detection/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import torch
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToPILImage
sys.path.append('../../..')
import tllib.vision.models.keypoint_detection as models
from tllib.vision.models.keypoint_detection.loss import JointsMSELoss
import tllib.vision.datasets.keypoint_detection as datasets
import tllib.vision.transforms.keypoint_detection as T
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict
from tllib.utils.metric.keypoint_detection import accuracy
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train or evaluate a source-only (ERM) keypoint detector.

    The model is trained on the labeled source dataset only; the target
    dataset is loaded so that target accuracy can be tracked per epoch and
    used for best-checkpoint selection.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN for reproducibility (slower training).
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: random geometric/photometric augmentation for training,
    # plain resize for validation.
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.image_size),
        T.ToTensor(),
        normalize
    ])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)

    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # infinite iterators so each epoch runs a fixed number of iterations
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    model = models.__dict__[args.arch](num_keypoints=train_source_dataset.num_keypoints).to(device)
    criterion = JointsMSELoss()

    # define optimizer and lr scheduler
    optimizer = Adam(model.get_parameters(lr=args.lr))
    lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)

    # optionally resume from a checkpoint (model, optimizer, scheduler, epoch)
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function (undo normalization, back to PIL)
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name):
        """
        Save a debug image with the given keypoints overlaid.

        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image; a '.jpg' suffix is appended here
        """
        train_source_dataset.visualize(tensor_to_image(image),
                                       keypoint2d, logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training; checkpoint selection is by target-domain accuracy
    best_acc = 0
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        # NOTE: the scheduler is stepped at the start of each epoch, before training
        lr_scheduler.step()

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion, optimizer, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def train(train_source_iter, train_target_iter, model, criterion,
          optimizer, epoch: int, visualize, args: argparse.Namespace):
    """Train ``model`` for one epoch on the source domain only (ERM baseline).

    Args:
        train_source_iter: infinite iterator over labeled source batches
        train_target_iter: infinite iterator over target batches (unused here;
            kept so the signature matches the adaptation variants)
        model: keypoint detection network producing heatmaps
        criterion: joint MSE loss over heatmaps
        optimizer: optimizer stepping on the source loss
        epoch: current epoch index (for logging only)
        visualize: optional callback(image, keypoints, name) that saves a
            debug image, or None to disable visualization
        args: command line arguments
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, acc_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # record the scalar value, not the tensor: storing the tensor would
        # keep the autograd graph alive and leak memory across iterations
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # the visualize callback appends the '.jpg' suffix itself,
                # so pass bare names to avoid doubled 'xxx.jpg.jpg' files
                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i))
def validate(val_loader, model, criterion, visualize, args: argparse.Namespace):
    """Evaluate ``model`` on ``val_loader`` and return per-group accuracies.

    Args:
        val_loader: validation DataLoader yielding (x, label, weight, meta)
        model: keypoint detection network producing heatmaps
        criterion: joint MSE loss over heatmaps
        visualize: optional callback(image, keypoints, name) that saves a
            debug image, or None to disable visualization
        args: command line arguments

    Returns:
        dict mapping each keypoint group name (including 'all') to its
        average accuracy over the validation set.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc['all']],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (x, label, weight, meta) in enumerate(val_loader):
            x = x.to(device)
            label = label.to(device)
            weight = weight.to(device)

            # compute output
            y = model(x)
            loss = criterion(y, label, weight)

            # measure accuracy and record loss
            losses.update(loss.item(), x.size(0))
            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),
                                                          label.cpu().numpy())
            group_acc = val_loader.dataset.group_accuracy(acc_per_points)
            acc.update(group_acc, x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                if visualize is not None:
                    # the visualize callback appends '.jpg' itself, so pass
                    # bare names to avoid doubled 'xxx.jpg.jpg' files
                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, "val_{}_pred".format(i))
                    visualize(x[0], meta['keypoint2d'][0], "val_{}_label".format(i))

    return acc.average()
if __name__ == '__main__':
    # discover all callable, lower-case model constructors and dataset classes
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='Source Only for Keypoint Detection Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),
                        help='scale range for the RandomResizeCrop augmentation')
    parser.add_argument('--rotation', type=int, default=180,
                        help='rotation range of the RandomRotation augmentation')
    parser.add_argument('--image-size', type=int, default=256,
                        help='input image size')
    parser.add_argument('--heatmap-size', type=int, default=64,
                        help='output heatmap size')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='pose_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: pose_resnet101)')
    parser.add_argument("--resume", type=str, default=None,
                        help="where restore model parameters from.")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    # fixed: `type=tuple` split any CLI value into single characters
    # (e.g. "45" -> ('4', '5')); use nargs='+' with type=int so milestone
    # epochs can actually be passed on the command line
    parser.add_argument('--lr-step', default=[45, 60], type=int, nargs='+',
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=70, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/keypoint_detection/erm.sh
================================================
# Source Only (ERM) baselines for keypoint detection domain adaptation.
# --debug saves visualization images under <log>/visualize/.

# Hands Dataset
CUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop \
    -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0
CUDA_VISIBLE_DEVICES=0 python erm.py data/FreiHand data/RHD \
    -s FreiHand -t RenderedHandPose --log logs/erm/freihand2rhd --debug --seed 0

# Body Dataset (smaller rotation range for body pose)
CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/Human36M \
    -s SURREAL -t Human36M --log logs/erm/surreal2human36m --debug --seed 0 --rotation 30
CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/lsp \
    -s SURREAL -t LSP --log logs/erm/surreal2lsp --debug --seed 0 --rotation 30

# Oracle Results (train and test on the same target domain)
CUDA_VISIBLE_DEVICES=0 python erm.py data/H3D_crop data/H3D_crop \
    -s Hand3DStudio -t Hand3DStudio --log logs/oracle/h3d --debug --seed 0
CUDA_VISIBLE_DEVICES=0 python erm.py data/Human36M data/Human36M \
    -s Human36M -t Human36M --log logs/oracle/human36m --debug --seed 0 --rotation 30
================================================
FILE: examples/domain_adaptation/keypoint_detection/regda.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import torch
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR, MultiStepLR
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToPILImage
sys.path.append('../../..')
from tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \
PseudoLabelGenerator2d, RegressionDisparity
import tllib.vision.models as models
from tllib.vision.models.keypoint_detection.pose_resnet import Upsampling, PoseResNet
from tllib.vision.models.keypoint_detection.loss import JointsKLLoss
import tllib.vision.datasets.keypoint_detection as datasets
import tllib.vision.transforms.keypoint_detection as T
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict
from tllib.utils.metric.keypoint_detection import accuracy
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run the RegDA pipeline: optional source-only pretraining, adversarial
    adaptation from the source to the target domain, and evaluation.

    When ``args.phase == 'test'`` only evaluation is performed.  Checkpoints,
    logs and (optionally) debug images are written under ``args.log`` via
    ``logger``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding trades speed for reproducibility: deterministic CUDNN
        # kernels disable some faster nondeterministic implementations.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Heavy augmentation for training; plain resize for validation.
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.image_size),
        T.ToTensor(),
        normalize
    ])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # Infinite iterators so every epoch runs a fixed number of iterations
    # regardless of dataset size.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints, num_head_layers=args.num_head_layers, finetune=True).to(device)

    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = PseudoLabelGenerator2d(num_keypoints, args.heatmap_size, args.heatmap_size)
    regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # Three separate optimizers: the feature extractor (f), main head (h) and
    # adversarial head (h_adv) are updated in different RegDA steps.
    optimizer_f = SGD([
        {'params': backbone.parameters(), 'lr': 0.1},
        {'params': upsampling.parameters(), 'lr': 0.1},
    ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # Inverse-decay schedule, stepped once per training iteration.
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # NOTE(review): stepping the scheduler before training shifts
                # the MultiStepLR decay by one epoch and, on recent PyTorch,
                # warns about running before optimizer.step() — confirm intended.
                lr_scheduler.step()
                print(lr_scheduler.get_lr())
                pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args)
                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save(
                        {
                            'model': pretrained_model.state_dict()
                        }, args.pretrain
                    )
                print("Source: {} best: {}".format(source_val_acc['all'], best_acc))
        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(tensor_to_image(image),
                                       keypoint2d, logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
              optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,
              epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        # model selection on target accuracy over all keypoint groups
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def pretrain(train_source_iter, model, criterion, optimizer,
             epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch on the source domain only.

    Args:
        train_source_iter: infinite iterator over the source training set,
            yielding ``(image, heatmap_label, keypoint_weight, meta)``.
        model: pose network; ``model(x)`` returns predicted heatmaps.
        criterion: heatmap loss, called as ``criterion(pred, label, weight)``.
        optimizer: optimizer over the model's parameters.
        epoch: current epoch index (for logging only).
        args: provides ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, acc_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # bug fix: record a plain float via .item(); storing the live loss
        # tensor would keep its autograd graph (and GPU memory) alive inside
        # the meter, and matches how validate() records losses.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,
          epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of RegDA adversarial adaptation.

    Each iteration performs the three RegDA steps:
      A. all sub-networks minimize the supervised source loss plus a
         margin-weighted disparity term;
      B. the adversarial head maximizes the regression disparity on target;
      C. the feature extractor minimizes the same disparity on target.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(image, heatmap_label, keypoint_weight, meta)`` batches.
        model: RegDA pose network; returns ``(y, y_adv)`` in train mode and
            exposes ``step()`` to finish one adversarial round.
        criterion: supervised heatmap loss.
        regression_disparity: disparity loss between the two heads; ``mode``
            selects minimization or maximization.
        optimizer_f / optimizer_h / optimizer_h_adv: optimizers for the
            feature extractor, main head and adversarial head.
        lr_scheduler_f / lr_scheduler_h / lr_scheduler_h_adv: per-iteration
            LR schedulers paired with the optimizers above.
        epoch: current epoch index (for logging only).
        visualize: optional debug-image callback, or ``None``.
        args: provides ``iters_per_epoch``, ``margin``, ``trade_off``,
            ``print_freq``, ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t, acc_s_adv, acc_t_adv],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)
        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()
        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),
                                                           label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),
                                                           label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # bug fix: record plain floats via .item(); storing the live loss
        # tensors would keep their autograd graphs (and GPU memory) alive
        # inside the meters, matching how validate() records losses.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # map heatmap-space predictions back to image coordinates
                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i))
                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0], "target_{}_label".format(i))
                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, "source_adv_{}_pred".format(i))
                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, "target_adv_{}_pred".format(i))
def validate(val_loader, model, criterion, visualize, args: argparse.Namespace):
    """Evaluate ``model`` on ``val_loader``.

    Returns a dict mapping each keypoint-group name (including ``'all'``)
    to its average accuracy over the validation set.  If ``visualize`` is
    given, prediction/label images are dumped every ``args.print_freq``
    batches.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc['all']],
        prefix='Test: ')

    # evaluation mode: freeze BN statistics, disable dropout
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (x, label, weight, meta) in enumerate(val_loader):
            x, label, weight = x.to(device), label.to(device), weight.to(device)

            # forward pass and loss
            y = model(x)
            losses.update(criterion(y, label, weight).item(), x.size(0))

            # per-keypoint accuracy, aggregated into the dataset's groups
            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),
                                                          label.cpu().numpy())
            acc.update(val_loader.dataset.group_accuracy(acc_per_points), x.size(0))

            # timing
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    # dump the first sample's prediction and ground truth
                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, "val_{}_pred.jpg".format(step))
                    visualize(x[0], meta['keypoint2d'][0], "val_{}_label.jpg".format(step))

    return acc.average()
if __name__ == '__main__':
    # collect the callable, lowercase backbone factories exposed by the models package
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='RegDA for Keypoint Detection Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),
                        help='scale range for the RandomResizeCrop augmentation')
    parser.add_argument('--rotation', type=int, default=180,
                        help='rotation range of the RandomRotation augmentation')
    parser.add_argument('--image-size', type=int, default=256,
                        help='input image size')
    parser.add_argument('--heatmap-size', type=int, default=64,
                        help='output heatmap size')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet101)')
    parser.add_argument("--pretrain", type=str, default=None,
                        help="Where restore pretrained model parameters from.")
    parser.add_argument("--resume", type=str, default=None,
                        help="where restore model parameters from.")
    parser.add_argument('--num-head-layers', type=int, default=2)
    parser.add_argument('--margin', type=float, default=4., help="margin gamma")
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--lr-gamma', default=0.0001, type=float)
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    # bug fix: the original used type=tuple, which would split a command-line
    # value such as '45' into the character tuple ('4', '5') and break
    # MultiStepLR; parse a list of ints instead (default unchanged).
    parser.add_argument('--lr-step', default=[45, 60], nargs='+', type=int,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs', default=30, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='regda',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/keypoint_detection/regda.sh
================================================
# Hands Dataset
# RegDA adaptation between hand-pose datasets: synthetic RHD / FreiHand as
# source, real H3D / RHD as target. --debug saves visualization images.
CUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \
-s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda/rhd2h3d
CUDA_VISIBLE_DEVICES=0 python regda.py data/FreiHand data/RHD \
-s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda/freihand2rhd
# Body Dataset
# SURREAL (synthetic) as source; Human3.6M and LSP (real) as target.
# --rotation 30 narrows the rotation augmentation for body images.
CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/Human36M \
-s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda/surreal2human36m
CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/lsp \
-s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda/surreal2lsp
================================================
FILE: examples/domain_adaptation/keypoint_detection/regda_fast.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import torch
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR, MultiStepLR
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToPILImage
sys.path.append('../../..')
from tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \
FastPseudoLabelGenerator2d, RegressionDisparity
import tllib.vision.models as models
from tllib.vision.models.keypoint_detection.pose_resnet import Upsampling, PoseResNet
from tllib.vision.models.keypoint_detection.loss import JointsKLLoss
import tllib.vision.datasets.keypoint_detection as datasets
import tllib.vision.transforms.keypoint_detection as T
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict
from tllib.utils.metric.keypoint_detection import accuracy
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run the fast-RegDA pipeline: optional source-only pretraining,
    adversarial adaptation from the source to the target domain, and
    evaluation.  Uses ``FastPseudoLabelGenerator2d`` for the disparity loss.

    When ``args.phase == 'test'`` only evaluation is performed.  Checkpoints,
    logs and (optionally) debug images are written under ``args.log`` via
    ``logger``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding trades speed for reproducibility: deterministic CUDNN
        # kernels disable some faster nondeterministic implementations.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Heavy augmentation for training; plain resize for validation.
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.image_size),
        T.ToTensor(),
        normalize
    ])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # Infinite iterators so every epoch runs a fixed number of iterations
    # regardless of dataset size.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints, num_head_layers=args.num_head_layers, finetune=True).to(device)

    # define loss function
    criterion = JointsKLLoss()
    # the "fast" variant of the pseudo label generator
    pseudo_label_generator = FastPseudoLabelGenerator2d()
    regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # Three separate optimizers: the feature extractor (f), main head (h) and
    # adversarial head (h_adv) are updated in different RegDA steps.
    optimizer_f = SGD([
        {'params': backbone.parameters(), 'lr': 0.1},
        {'params': upsampling.parameters(), 'lr': 0.1},
    ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # Inverse-decay schedule, stepped once per training iteration.
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # NOTE(review): stepping the scheduler before training shifts
                # the MultiStepLR decay by one epoch and, on recent PyTorch,
                # warns about running before optimizer.step() — confirm intended.
                lr_scheduler.step()
                print(lr_scheduler.get_lr())
                pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args)
                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save(
                        {
                            'model': pretrained_model.state_dict()
                        }, args.pretrain
                    )
                print("Source: {} best: {}".format(source_val_acc['all'], best_acc))
        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(tensor_to_image(image),
                                       keypoint2d, logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
              optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,
              epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        # model selection on target accuracy over all keypoint groups
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def pretrain(train_source_iter, model, criterion, optimizer,
             epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch on the source domain only.

    Args:
        train_source_iter: infinite iterator over the source training set,
            yielding ``(image, heatmap_label, keypoint_weight, meta)``.
        model: pose network; ``model(x)`` returns predicted heatmaps.
        criterion: heatmap loss, called as ``criterion(pred, label, weight)``.
        optimizer: optimizer over the model's parameters.
        epoch: current epoch index (for logging only).
        args: provides ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, acc_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # bug fix: record a plain float via .item(); storing the live loss
        # tensor would keep its autograd graph (and GPU memory) alive inside
        # the meter, and matches how validate() records losses.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv,
          epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of RegDA adversarial adaptation.

    Each iteration performs the three RegDA steps:
      A. all sub-networks minimize the supervised source loss plus a
         margin-weighted disparity term;
      B. the adversarial head maximizes the regression disparity on target;
      C. the feature extractor minimizes the same disparity on target.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            ``(image, heatmap_label, keypoint_weight, meta)`` batches.
        model: RegDA pose network; returns ``(y, y_adv)`` in train mode and
            exposes ``step()`` to finish one adversarial round.
        criterion: supervised heatmap loss.
        regression_disparity: disparity loss between the two heads; ``mode``
            selects minimization or maximization.
        optimizer_f / optimizer_h / optimizer_h_adv: optimizers for the
            feature extractor, main head and adversarial head.
        lr_scheduler_f / lr_scheduler_h / lr_scheduler_h_adv: per-iteration
            LR schedulers paired with the optimizers above.
        epoch: current epoch index (for logging only).
        visualize: optional debug-image callback, or ``None``.
        args: provides ``iters_per_epoch``, ``margin``, ``trade_off``,
            ``print_freq``, ``image_size`` and ``heatmap_size``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t, acc_s_adv, acc_t_adv],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)
        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()
        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),
                                                           label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),
                                                           label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # bug fix: record plain floats via .item(); storing the live loss
        # tensors would keep their autograd graphs (and GPU memory) alive
        # inside the meters, matching how validate() records losses.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # map heatmap-space predictions back to image coordinates
                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i))
                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0], "target_{}_label".format(i))
                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, "source_adv_{}_pred".format(i))
                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, "target_adv_{}_pred".format(i))
def validate(val_loader, model, criterion, visualize, args: argparse.Namespace):
    """Run one evaluation pass over ``val_loader`` and return per-group accuracy.

    Args:
        val_loader: validation data loader; its dataset must expose
            ``keypoints_group`` and ``group_accuracy``.
        model: keypoint detection model (returns heatmaps in eval mode).
        criterion: heatmap loss, called as ``criterion(pred, label, weight)``.
        visualize: optional callable saving debug images, or None.
        args: command-line arguments (uses print_freq, image_size, heatmap_size).

    Returns:
        Average accuracy per keypoint group (result of ``acc.average()``).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc['all']],
        prefix='Test: ')

    # inference only: freeze BN/dropout statistics and skip autograd
    model.eval()
    with torch.no_grad():
        tic = time.time()
        for step, (images, heatmap_gt, kp_weight, meta) in enumerate(val_loader):
            images = images.to(device)
            heatmap_gt = heatmap_gt.to(device)
            kp_weight = kp_weight.to(device)

            # compute output
            heatmap_pred = model(images)
            loss = criterion(heatmap_pred, heatmap_gt, kp_weight)

            # record loss and per-group accuracy
            batch = images.size(0)
            losses.update(loss.item(), batch)
            acc_per_points, _, _, pred = accuracy(heatmap_pred.cpu().numpy(),
                                                  heatmap_gt.cpu().numpy())
            acc.update(val_loader.dataset.group_accuracy(acc_per_points), batch)

            # measure elapsed time
            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    # predictions live in heatmap coordinates; rescale to image size
                    visualize(images[0], pred[0] * args.image_size / args.heatmap_size,
                              "val_{}_pred.jpg".format(step))
                    visualize(images[0], meta['keypoint2d'][0], "val_{}_label.jpg".format(step))
    return acc.average()
if __name__ == '__main__':
    # enumerate available backbones / datasets from the imported modules
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='RegDA (fast) for Keypoint Detection Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3),
                        help='scale range for the RandomResizeCrop augmentation')
    parser.add_argument('--rotation', type=int, default=180,
                        help='rotation range of the RandomRotation augmentation')
    parser.add_argument('--image-size', type=int, default=256,
                        help='input image size')
    parser.add_argument('--heatmap-size', type=int, default=64,
                        help='output heatmap size')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet101)')
    parser.add_argument("--pretrain", type=str, default=None,
                        help="Where restore pretrained model parameters from.")
    parser.add_argument("--resume", type=str, default=None,
                        help="where restore model parameters from.")
    parser.add_argument('--num-head-layers', type=int, default=2)
    parser.add_argument('--margin', type=float, default=4., help="margin gamma")
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--lr-gamma', default=0.0001, type=float)
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    # FIX: this argument previously used ``type=tuple``, which argparse applies
    # to the raw command-line string — "45" would become ('4', '5').  It only
    # appeared to work because ``type`` is never applied to the default value.
    # ``nargs='+', type=int`` accepts e.g. ``--lr-step 45 60`` and keeps the
    # same default, so existing behavior without the flag is unchanged.
    parser.add_argument('--lr-step', default=[45, 60], nargs='+', type=int,
                        help='epochs at which to decay lr (parameter for lr scheduler)')
    parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N',
                        help='number of epochs for source-only pretraining')
    parser.add_argument('--epochs', default=30, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='regda_fast',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/keypoint_detection/regda_fast.sh
================================================
# regda_fast is provided by https://github.com/YouJiacheng?tab=repositories
# On single V100(16G), overall adversarial training time is reduced by about 40%.
# yet the PCK might drop 1% for each dataset.

# Hands Dataset
# RHD -> H3D
CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/RHD data/H3D_crop \
    -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda_fast/rhd2h3d
# FreiHand -> RHD
CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/FreiHand data/RHD \
    -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda_fast/freihand2rhd

# Body Dataset
# SURREAL -> Human3.6M (smaller rotation range, fewer epochs)
CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/Human36M \
    -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda_fast/surreal2human36m
# SURREAL -> LSP (smaller rotation range)
CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/lsp \
    -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda_fast/surreal2lsp
================================================
FILE: examples/domain_adaptation/object_detection/README.md
================================================
# Unsupervised Domain Adaptation for Object Detection
## Updates
- *04/2022*: Provide CycleGAN translated datasets.
## Installation
Our code is based on [Detectron2 (latest, v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html); please install it before usage.
The following is an example based on PyTorch 1.9.0 with CUDA 11.1. For other versions, please refer to
the official website of [PyTorch](https://pytorch.org/) and
[Detectron](https://detectron2.readthedocs.io/en/latest/tutorials/install.html).
```shell
# create environment
conda create -n detection python=3.8.3
# activate environment
conda activate detection
# install pytorch
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# install detectron
python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html
# install other requirements
pip install -r requirements.txt
```
## Dataset
Following datasets can be downloaded automatically:
- [PASCAL_VOC 07+12](http://host.robots.ox.ac.uk/pascal/VOC/)
- Clipart
- WaterColor
- Comic
You need to prepare following datasets manually if you want to use them:
#### Cityscapes, Foggy Cityscapes
- Download Cityscapes and Foggy Cityscapes dataset from the [link](https://www.cityscapes-dataset.com/downloads/). Particularly, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes.
- Unzip them under the directory like
```
object_detection/datasets/cityscapes
├── gtFine
├── leftImg8bit
├── leftImg8bit_foggy
└── ...
```
Then run
```
python prepare_cityscapes_to_voc.py
```
This will automatically generate dataset in `VOC` format.
```
object_detection/datasets/cityscapes_in_voc
├── Annotations
├── ImageSets
└── JPEGImages
object_detection/datasets/foggy_cityscapes_in_voc
├── Annotations
├── ImageSets
└── JPEGImages
```
#### Sim10k
- Download Sim10k dataset from the following links: [Sim10k](https://fcav.engin.umich.edu/projects/driving-in-the-matrix). Particularly, we use *repro_10k_images.tgz* , *repro_image_sets.tgz* and *repro_10k_annotations.tgz* for Sim10k.
- Extract the training set from *repro_10k_images.tgz*, *repro_image_sets.tgz* and *repro_10k_annotations.tgz*, then rename directory `VOC2012/` to `sim10k/`.
After preparation, there should exist following files:
```
object_detection/datasets/
├── VOC2007
│ ├── Annotations
│ ├──ImageSets
│ └──JPEGImages
├── VOC2012
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
├── clipart
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
├── watercolor
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
├── comic
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
├── cityscapes_in_voc
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
├── foggy_cityscapes_in_voc
│ ├── Annotations
│ ├── ImageSets
│ └── JPEGImages
└── sim10k
├── Annotations
├── ImageSets
└── JPEGImages
```
**Note**: The above is a tutorial for using standard datasets. To use your own datasets,
you need to convert them into corresponding format.
#### CycleGAN translated dataset
The following command use CycleGAN to translate VOC (with directory `datasets/VOC2007` and `datasets/VOC2012`) to Clipart (with directory `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`).
```
mkdir datasets/VOC2007_to_clipart
cp -r datasets/VOC2007/* datasets/VOC2007_to_clipart
mkdir datasets/VOC2012_to_clipart
cp -r datasets/VOC2012/* datasets/VOC2012_to_clipart
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
-s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
--translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \
--log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9
```
You can also download and use datasets that are translated by us.
- PASCAL_VOC to Clipart [[07]](https://cloud.tsinghua.edu.cn/f/1b6b060d202145aea416/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/818dbd8e41a043fab7c3/?dl=1) (with directory `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`)
- PASCAL_VOC to Comic [[07]](https://cloud.tsinghua.edu.cn/f/89382bba64514210a9f8/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/f90289137fd5465f806d/?dl=1) (with directory `datasets/VOC2007_to_comic` and `datasets/VOC2012_to_comic`)
- PASCAL_VOC to WaterColor [[07]](https://cloud.tsinghua.edu.cn/f/8e982e9f21294b38be8a/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/b8235034cb4247ce809f/?dl=1) (with directory `datasets/VOC2007_to_watercolor` and `datasets/VOC2012_to_watercolor`)
- Cityscapes to Foggy Cityscapes [[Part1]](https://cloud.tsinghua.edu.cn/f/09ceeb25a476481bae29/?dl=1) [[Part2]](https://cloud.tsinghua.edu.cn/f/51fb05d3ee614e7d87a0/?dl=1) [[Part3]](https://cloud.tsinghua.edu.cn/f/646415daf6b344c3a9e3/?dl=1) [[Part4]](https://cloud.tsinghua.edu.cn/f/008d5d3c54344f83b101/?dl=1) (with directory `datasets/cityscapes_to_foggy_cityscapes`). Note that you need to use ``cat`` to merge the downloaded files.
- Sim10k to Cityscapes (Car) [[Download]](https://cloud.tsinghua.edu.cn/f/33ac656fcde34f758dcd/?dl=1) (with directory `datasets/sim10k2cityscapes_car`).
## Supported Methods
Supported methods include:
- [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf)
- [Decoupled Adaptation for Cross-Domain Object Detection (D-adapt)](https://arxiv.org/abs/2110.02578)
## Experiment and Results
The shell files give the script to reproduce the [benchmarks](/docs/dalib/benchmarks/object_detection.rst) with specified hyper-parameters.
The basic training pipeline is as follows.
The following command trains a Faster-RCNN detector on task VOC->Clipart, with only source (VOC) data.
```
CUDA_VISIBLE_DEVICES=0 python source_only.py \
--config-file config/faster_rcnn_R_101_C4_voc.yaml \
-s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
--test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \
OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart
```
Explanation of some arguments
- `--config-file`: path to config file that specifies training hyper-parameters.
- `-s`: a list that specifies source datasets, for each dataset you should pass in a `(name, path)` pair, in the
above command, there are two source datasets **VOC2007** and **VOC2012**.
- `-t`: a list that specifies target datasets, same format as above.
- `--test`: a list that specifies test datasets, same format as above.
### VOC->Clipart
| | | AP | AP50 | AP75 | aeroplane | bicycle | bird | boat | bottle | bus | car | cat | chair | cow | diningtable | dog | horse | motorbike | person | pottedplant | sheep | sofa | train | tvmonitor |
|-------------------------|----------|------|------|------|-----------|---------|------|------|--------|------|------|------|-------|------|-------------|------|-------|-----------|--------|-------------|-------|------|-------|-----------|
| Faster RCNN (ResNet101) | Source | 14.9 | 29.3 | 12.6 | 29.6 | 38.0 | 24.7 | 21.7 | 31.9 | 48.0 | 30.8 | 15.9 | 32.0 | 19.2 | 18.2 | 12.1 | 28.2 | 48.8 | 38.3 | 34.6 | 3.8 | 22.5 | 43.7 | 44.0 |
| | CycleGAN | 20.0 | 37.7 | 18.3 | 37.1 | 41.9 | 29.9 | 26.5 | 40.9 | 65.1 | 37.8 | 23.8 | 40.7 | 48.9 | 12.7 | 14.4 | 27.8 | 63.0 | 55.1 | 40.1 | 8.0 | 30.7 | 54.1 | 55.7 |
| | D-adapt | 24.8 | 49.0 | 21.5 | 56.4 | 63.2 | 42.3 | 40.9 | 45.3 | 77.0 | 48.7 | 25.4 | 44.3 | 58.4 | 31.4 | 24.5 | 47.1 | 75.3 | 69.3 | 43.5 | 27.9 | 34.1 | 60.7 | 64.0 |
| | | | | | | | | | | | | | | | | | | | | | | | | |
| RetinaNet | Source | 18.3 | 32.2 | 17.6 | 34.2 | 42.4 | 27.0 | 21.6 | 36.8 | 48.4 | 35.9 | 16.4 | 38.9 | 22.6 | 27.0 | 15.1 | 27.1 | 46.7 | 42.1 | 36.2 | 8.3 | 29.5 | 42.1 | 46.2 |
| | D-adapt | 25.1 | 46.3 | 23.9 | 47.4 | 65.0 | 33.1 | 37.5 | 56.8 | 61.2 | 55.1 | 27.3 | 45.5 | 51.8 | 29.1 | 29.6 | 38.0 | 74.5 | 66.7 | 46.0 | 24.2 | 29.3 | 54.2 | 53.8 |
### VOC->WaterColor
| | AP | AP50 | AP75 | bicycle | bird | car | cat | dog | person |
|-------------------------|------|------|------|---------|------|------|------|------|--------|
| Faster RCNN (ResNet101) | 23.0 | 45.9 | 18.5 | 71.1 | 48.3 | 48.6 | 23.7 | 23.3 | 60.3 |
| CycleGAN | 24.9 | 50.8 | 22.4 | 75.8 | 52.1 | 49.8 | 30.1 | 33.4 | 63.6 |
| D-adapt | 28.5 | 57.5 | 23.6 | 77.4 | 54.0 | 52.8 | 43.9 | 48.1 | 68.9 |
| Target | 23.8 | 51.3 | 17.4 | 48.5 | 54.7 | 41.3 | 36.2 | 52.6 | 74.6 |
### VOC->Comic
| | AP | AP50 | AP75 | bicycle | bird | car | cat | dog | person |
|:-----------------------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----:|:----:|:------:|
| Faster RCNN (ResNet101) | 13.0 | 25.5 | 11.4 | 33.0 | 15.8 | 28.9 | 16.8 | 19.6 | 39.0 |
| CycleGAN | 16.9 | 34.6 | 14.2 | 28.1 | 25.7 | 37.7 | 28.0 | 33.8 | 54.1 |
| D-adapt | 20.8 | 41.1 | 18.5 | 49.4 | 25.7 | 43.3 | 36.9 | 32.7 | 58.5 |
| Target | 21.9 | 44.6 | 16.0 | 40.7 | 32.3 | 38.3 | 43.9 | 41.3 | 71.0 |
### Cityscapes->Foggy Cityscapes
| | | AP | AP50 | AP75 | bicycle | bus | car | motorcycle | person | rider | train | truck |
|:-----------------------:|:--------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----------:|:------:|:-----:|:-----:|:-----:|
| Faster RCNN (VGG16) | Source | 14.3 | 25.9 | 13.2 | 33.6 | 27.0 | 40.0 | 22.3 | 31.3 | 38.5 | 2.3 | 12.2 |
| | CycleGAN | 22.5 | 41.6 | 20.7 | 46.5 | 41.5 | 62.0 | 33.8 | 45.0 | 54.5 | 21.7 | 27.7 |
| | D-adapt | 19.4 | 38.1 | 17.5 | 42.0 | 36.8 | 58.1 | 32.2 | 43.1 | 51.8 | 14.6 | 26.3 |
| | Target | 24.0 | 45.3 | 21.3 | 45.9 | 47.4 | 67.3 | 39.7 | 49.0 | 53.2 | 30.0 | 29.6 |
| | | | | | | | | | | | | |
| Faster RCNN (ResNet101) | Source | 18.8 | 33.3 | 19.0 | 36.1 | 34.5 | 43.8 | 24.0 | 36.3 | 39.9 | 29.1 | 22.8 |
| | CycleGAN | 22.9 | 41.8 | 21.9 | 42.0 | 44.5 | 57.6 | 36.3 | 40.9 | 48.0 | 30.8 | 34.3 |
| | D-adapt | 22.7 | 42.4 | 21.6 | 41.8 | 44.4 | 56.6 | 31.4 | 41.8 | 48.6 | 42.3 | 32.4 |
| | Target | 25.5 | 45.3 | 24.3 | 41.9 | 53.2 | 63.4 | 36.1 | 42.6 | 47.9 | 42.4 | 35.3 |
### Sim10k->Cityscapes Car
| | | AP | AP50 | AP75 |
|:-----------------------:|:--------:|:----:|:----:|:----:|
| Faster RCNN (VGG16) | Source | 24.8 | 43.4 | 23.6 |
| | CycleGAN | 29.3 | 51.9 | 28.6 |
| | D-adapt | 23.6 | 48.5 | 18.7 |
| | Target | 24.8 | 43.4 | 23.6 |
| | | | | |
| Faster RCNN (ResNet101) | Source | 24.6 | 44.4 | 23.0 |
| | CycleGAN | 26.5 | 47.4 | 24.0 |
| | D-adapt | 27.4 | 51.9 | 25.7 |
| | Target | 24.6 | 44.4 | 23.0 |
### Visualization
We provide code for visualization in `visualize.py`. For example, suppose you have trained the source only model
of task VOC->Clipart using provided scripts. The following code visualizes the prediction of the
detector on Clipart.
```shell
CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \
--test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \
MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth
```
Explanation of some arguments
- `--test`: a list that specifies test datasets for visualization.
- `--save-path`: where to save visualization results.
- `MODEL.WEIGHTS`: path to the model.
## TODO
Support methods: SWDA, Global/Local Alignment
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{jiang2021decoupled,
title = {Decoupled Adaptation for Cross-Domain Object Detection},
author = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long},
booktitle = {ICLR},
year = {2022}
}
@inproceedings{CycleGAN,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle={ICCV},
year={2017}
}
```
================================================
FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_cityscapes.yaml
================================================
# Faster R-CNN with a ResNet-101-C4 backbone for Cityscapes (8 object classes).
MODEL:
  META_ARCHITECTURE: "TLGeneralizedRCNN"  # transfer-learning variant of GeneralizedRCNN
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"  # ImageNet-pretrained backbone
  MASK_ON: False
  RESNETS:
    DEPTH: 101
  ROI_HEADS:
    NAME: "TLRes5ROIHeads"
    NUM_CLASSES: 8
    BATCH_SIZE_PER_IMAGE: 512
  ANCHOR_GENERATOR:
    SIZES: [ [ 64, 128, 256, 512 ] ]
  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 1000
    BATCH_SIZE_PER_IMAGE: 256
  PROPOSAL_GENERATOR:
    NAME: "TLRPN"
INPUT:
  # multi-scale training: shorter side sampled from this list
  MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,)
  MIN_SIZE_TEST: 608
  MAX_SIZE_TRAIN: 1166
DATASETS:
  TRAIN: ("cityscapes_trainval",)
  TEST: ("cityscapes_test",)
SOLVER:
  STEPS: (12000,)  # iteration at which the learning rate is decayed
  MAX_ITER: 16000 # 16 epochs
  WARMUP_ITERS: 100
  CHECKPOINT_PERIOD: 2000
  IMS_PER_BATCH: 2
  BASE_LR: 0.005
TEST:
  EVAL_PERIOD: 2000
VIS_PERIOD: 500
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_voc.yaml
================================================
# Faster R-CNN with a ResNet-101-C4 backbone for PASCAL VOC (20 object classes).
MODEL:
  META_ARCHITECTURE: "TLGeneralizedRCNN"  # transfer-learning variant of GeneralizedRCNN
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"  # ImageNet-pretrained backbone
  MASK_ON: False
  RESNETS:
    DEPTH: 101
  ROI_HEADS:
    NAME: "TLRes5ROIHeads"
    NUM_CLASSES: 20
    BATCH_SIZE_PER_IMAGE: 256
  ANCHOR_GENERATOR:
    SIZES: [ [ 64, 128, 256, 512 ] ]
  RPN:
    PRE_NMS_TOPK_TEST: 6000
    POST_NMS_TOPK_TEST: 1000
    BATCH_SIZE_PER_IMAGE: 128
  PROPOSAL_GENERATOR:
    NAME: "TLRPN"
INPUT:
  # multi-scale training: shorter side sampled from this list
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,)
  MIN_SIZE_TEST: 608
  MAX_SIZE_TRAIN: 1166
DATASETS:
  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test',)
SOLVER:
  STEPS: (12000, )  # iteration at which the learning rate is decayed
  MAX_ITER: 16000 # 16 epochs
  WARMUP_ITERS: 100
  CHECKPOINT_PERIOD: 2000
  IMS_PER_BATCH: 4
  BASE_LR: 0.005
TEST:
  EVAL_PERIOD: 2000
VIS_PERIOD: 500
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_vgg_16_cityscapes.yaml
================================================
# Faster R-CNN with a VGG-16 + FPN backbone for Cityscapes (8 object classes).
MODEL:
  META_ARCHITECTURE: "TLGeneralizedRCNN"  # transfer-learning variant of GeneralizedRCNN
  # caffe-style VGG-16 weights; hence the explicit RGB pixel mean/std below
  WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth'
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
  MASK_ON: False
  BACKBONE:
    NAME: "build_vgg_fpn_backbone"
  ROI_HEADS:
    IN_FEATURES: ["p3", "p4", "p5", "p6"]
    NAME: "TLStandardROIHeads"
    NUM_CLASSES: 8
  ROI_BOX_HEAD:
    NAME: "FastRCNNConvFCHead"
    NUM_FC: 2
    POOLER_RESOLUTION: 7
  ANCHOR_GENERATOR:
    SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map
    ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ] # Three aspect ratios (same for all in feature maps)
  RPN:
    IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
    PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
    PRE_NMS_TOPK_TEST: 1000 # Per FPN level
    # Detectron1 uses 2000 proposals per-batch,
    # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
    # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
    POST_NMS_TOPK_TRAIN: 1000
    POST_NMS_TOPK_TEST: 1000
  PROPOSAL_GENERATOR:
    NAME: "TLRPN"
INPUT:
  FORMAT: "RGB"  # required: the pretrained VGG weights expect RGB channel order
  MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
  MIN_SIZE_TEST: 800
  MAX_SIZE_TEST: 1280
  MAX_SIZE_TRAIN: 1280
DATASETS:
  TRAIN: ("cityscapes_trainval",)
  TEST: ("cityscapes_test",)
SOLVER:
  STEPS: (12000,)  # iteration at which the learning rate is decayed
  MAX_ITER: 16000 # 16 epochs
  WARMUP_ITERS: 100
  CHECKPOINT_PERIOD: 2000
  IMS_PER_BATCH: 8
  BASE_LR: 0.01
TEST:
  EVAL_PERIOD: 2000
VIS_PERIOD: 500
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/config/retinanet_R_101_FPN_voc.yaml
================================================
# RetinaNet with a ResNet-101 + FPN backbone for PASCAL VOC (20 object classes).
MODEL:
  META_ARCHITECTURE: "TLRetinaNet"  # transfer-learning variant of RetinaNet
  WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"  # ImageNet-pretrained backbone
  BACKBONE:
    NAME: "build_retinanet_resnet_fpn_backbone"
  MASK_ON: False
  RESNETS:
    DEPTH: 101
    OUT_FEATURES: [ "res4", "res5" ]
  ANCHOR_GENERATOR:
    # three anchor scales per octave: x, x*2^(1/3), x*2^(2/3)
    SIZES: !!python/object/apply:eval [ "[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]" ]
  RETINANET:
    NUM_CLASSES: 20
    IN_FEATURES: ["p4", "p5", "p6", "p7"]
    IOU_THRESHOLDS: [ 0.4, 0.5 ]
    IOU_LABELS: [ 0, -1, 1 ]
    SMOOTH_L1_LOSS_BETA: 0.0
  FPN:
    IN_FEATURES: ["res4", "res5"]
INPUT:
  # multi-scale training: shorter side sampled from this list
  MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, )
  MIN_SIZE_TEST: 608
  MAX_SIZE_TRAIN: 1166
DATASETS:
  TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
  TEST: ('voc_2007_test',)
SOLVER:
  STEPS: (12000, )  # iteration at which the learning rate is decayed
  MAX_ITER: 16000 # 16 epochs
  WARMUP_ITERS: 100
  CHECKPOINT_PERIOD: 2000
  IMS_PER_BATCH: 8
  BASE_LR: 0.005
TEST:
  EVAL_PERIOD: 2000
VIS_PERIOD: 500
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/cycle_gan.py
================================================
"""
CycleGAN for VOC-format Object Detection Dataset
You need to modify function build_dataset if you want to use your own dataset.
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools
import os
import tqdm
from typing import Optional, Callable, Tuple, Any, List
from PIL import Image
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, ConcatDataset
from torchvision.transforms import ToPILImage, Compose
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
import torchvision.transforms as T
sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are multiples of ``base`` (at least ``base``).

    Each side is rounded to the nearest multiple of ``base``; the original
    image object is returned unchanged when no resize is needed.
    """
    width, height = img.size
    new_w = int(max(round(width / base), 1) * base)
    new_h = int(max(round(height / base), 1) * base)
    if (new_w, new_h) == (width, height):
        return img
    return img.resize((new_w, new_h), method)
class VOCImageFolder(datasets.VisionDataset):
    """A VOC-format dataset class for image translation.

    Image ids are read from ``<root>/ImageSets/Main/<phase>.txt``; relative ids
    are resolved against ``<root>/JPEGImages``.
    """
    def __init__(self, root: str, phase='trainval',
                 transform: Optional[Callable] = None, extension='.jpg'):
        super().__init__(root, transform=transform)
        data_list_file = os.path.join(root, "ImageSets/Main/{}.txt".format(phase))
        self.samples = self.parse_data_file(data_list_file, extension)
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, str]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, path) where ``image`` is the (possibly transformed)
            loaded image and ``path`` is its file path.
        """
        path = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, path

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str, extension: str) -> List[str]:
        """Parse the image-set file into a list of image paths.

        Args:
            file_name (str): path of the image-set file, one image id per line
            extension (str): extension appended to each id (None keeps ids as-is)

        Returns:
            list: list of image file paths
        """
        with open(file_name, "r") as f:
            data_list = []
            for line in f.readlines():
                line = line.strip()
                if extension is None:
                    path = line
                else:
                    path = line + extension
                # relative ids are resolved under <root>/JPEGImages
                if not os.path.isabs(path):
                    path = os.path.join(self.root, "JPEGImages", path)
                data_list.append((path))
            return data_list

    def translate(self, transform: Callable, target_root: str, image_base=4):
        """Translate every image of the dataset and save it under ``target_root``.

        Args:
            transform (callable): a transform that maps an image from one
                domain to another domain
            target_root (str): the root directory to save translated images
            image_base (int): images are first resized so both sides are
                multiples of this value, then resized back afterwards
        """
        os.makedirs(target_root, exist_ok=True)
        for path in tqdm.tqdm(self.samples):
            image = Image.open(path).convert('RGB')
            # mirror the source directory layout under target_root
            translated_path = path.replace(self.root, target_root)
            ow, oh = image.size
            image = make_power_2(image, image_base)
            translated_image = transform(image)
            # restore the original resolution so annotations still align
            translated_image = translated_image.resize((ow, oh))
            os.makedirs(os.path.dirname(translated_path), exist_ok=True)
            translated_image.save(translated_path)
def main(args):
    """Train CycleGAN between the source and target detection datasets, then
    optionally translate both datasets to disk with the learned generators.

    Args:
        args (argparse.Namespace): command-line arguments; see the argument
            parser of this script for the meaning of each field.
    """
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        # normalize to [-1, 1], matching the generators' tanh output range
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # args.source / args.target are flat (name, path, name, path, ...) lists;
    # even indices are dataset names, odd indices are dataset roots
    train_source_dataset = build_dataset(args.source[::2], args.source[1::2], train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    train_target_dataset = build_dataset(args.target[::2], args.target[1::2], train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    # constant lr for the first args.epochs epochs, then linear decay to zero
    # over the following args.epochs_decay epochs
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'train':
        # define loss function
        criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
        criterion_cycle = nn.L1Loss()
        criterion_identity = nn.L1Loss()

        # define visualization function: undo the [-1, 1] normalization above
        tensor_to_image = Compose([
            Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToPILImage()
        ])

        def visualize(image, name):
            """Save a tensor image under the logger's image directory.

            Args:
                image (tensor): image in shape 3 x H x W
                name: name of the saving image
            """
            tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))

        # start training
        for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):
            logger.set_epoch(epoch)
            print(lr_scheduler_G.get_lr())

            # train for one epoch
            train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
                  criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
                  fake_S_pool, fake_T_pool, epoch, visualize, args)

            # update learning rates
            lr_scheduler_G.step()
            lr_scheduler_D.step()

            # save checkpoint
            torch.save(
                {
                    'netG_S2T': netG_S2T.state_dict(),
                    'netG_T2S': netG_T2S.state_dict(),
                    'netD_S': netD_S.state_dict(),
                    'netD_T': netD_T.state_dict(),
                    'optimizer_G': optimizer_G.state_dict(),
                    'optimizer_D': optimizer_D.state_dict(),
                    'lr_scheduler_G': lr_scheduler_G.state_dict(),
                    'lr_scheduler_D': lr_scheduler_D.state_dict(),
                    'epoch': epoch,
                    'args': args
                }, logger.get_checkpoint_path('latest')
            )

    # translate the datasets to disk with the trained generators
    if args.translated_source is not None:
        transform = cyclegan.transform.Translation(netG_S2T, device)
        for dataset, translated_source in zip(train_source_dataset.datasets, args.translated_source):
            dataset.translate(transform, translated_source, image_base=args.image_base)
    if args.translated_target is not None:
        transform = cyclegan.transform.Translation(netG_T2S, device)
        for dataset, translated_target in zip(train_target_dataset.datasets, args.translated_target):
            dataset.translate(transform, translated_target, image_base=args.image_base)

    logger.close()
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
          criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):
    """Run one CycleGAN training epoch.

    Args:
        train_source_iter: infinite iterator over source-domain image batches
        train_target_iter: infinite iterator over target-domain image batches
        netG_S2T: generator translating source -> target
        netG_T2S: generator translating target -> source
        netD_S: discriminator for source-domain images
        netD_T: discriminator for target-domain images
        criterion_gan: adversarial (GAN) loss
        criterion_cycle: cycle-consistency loss (L1)
        criterion_identity: identity-mapping loss (L1)
        optimizer_G: joint optimizer over both generators
        optimizer_D: joint optimizer over both discriminators
        fake_S_pool: buffer of previously generated fake source images,
            sampled when updating netD_S
        fake_T_pool: buffer of previously generated fake target images,
            sampled when updating netD_T
        epoch (int): current epoch index (used for logging only)
        visualize: callback saving a single image tensor under a given name
        args (argparse.Namespace): training arguments (iters_per_epoch,
            trade_off_cycle, trade_off_identity, print_freq)
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()

    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients while the generators are updated
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S)): generator tries to make D_T label fakes as real
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(B))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        # re-enable discriminator gradients that were frozen above
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S
        # sample from the history buffer (detach: no gradient into the generator)
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # measure elapsed time
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            # dump the first image of every intermediate tensor for inspection
            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T", "rec_S",
                                     "rec_T", "identity_S", "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))
def build_dataset(dataset_names, dataset_roots, transform):
    """Build a concatenated dataset from parallel name/root sequences.

    Each (dataset class name, root directory) pair is turned into a
    ``VOCImageFolder`` using the split and file extension that the dataset
    requires; the per-dataset folders are then merged into one
    ``ConcatDataset``.
    """
    built = []
    for name, root in zip(dataset_names, dataset_roots):
        # default configuration; overridden per dataset family below
        kwargs = {'phase': 'trainval', 'transform': transform}
        if name in ("WaterColor", "Comic"):
            kwargs['phase'] = 'train'
        elif name in ("Cityscapes", "FoggyCityscapes"):
            kwargs['extension'] = '.png'
        elif name == "Sim10k":
            kwargs['phase'] = 'trainval10k'
        built.append(VOCImageFolder(root, **kwargs))
    return ConcatDataset(built)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CycleGAN for Segmentation')
    # dataset parameters
    parser.add_argument('-s', '--source', nargs='+', help='source domain(s)')
    parser.add_argument('-t', '--target', nargs='+', help='target domain(s)')
    parser.add_argument('--rotation', type=int, default=0,
                        help='rotation range of the RandomRotation augmentation')
    # NOTE(review): judging by the defaults, --resize-ratio looks like torchvision's
    # `scale` (area range) and --resize-scale like its `ratio` (aspect range) --
    # confirm against the call site before renaming.
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(0.5, 1.0),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(3./4., 4./3.),
                        help='the resize scale for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 512),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int,
                        metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    # fix: help text used to claim "(default: 100)" while the default is 500
    parser.add_argument('-p', '--print-freq', default=500, type=int,
                        metavar='N', help='print frequency (default: 500)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='cyclegan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(512, 512),
                        help='the input image size during test')
    parser.add_argument('--translated-source', type=str, default=None, nargs='+',
                        help="The root to put the translated source dataset")
    parser.add_argument('--translated-target', type=str, default=None, nargs='+',
                        help="The root to put the translated target dataset")
    parser.add_argument('--image-base', default=4, type=int,
                        help='the input image will be multiple of image-base before translated')
    args = parser.parse_args()
    print(args)
    main(args)
================================================
FILE: examples/domain_adaptation/object_detection/cycle_gan.sh
================================================
# Translate source images with CycleGAN, then train a Faster R-CNN detector on
# the union of the original and the translated source data.
# mkdir -p keeps the dataset-copy steps idempotent when the script is re-run.

# VOC to Clipart
mkdir -p datasets/VOC2007_to_clipart
cp -r datasets/VOC2007/* datasets/VOC2007_to_clipart
mkdir -p datasets/VOC2012_to_clipart
cp -r datasets/VOC2012/* datasets/VOC2012_to_clipart
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
  --translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \
  --log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 VOC2007 datasets/VOC2007_to_clipart VOC2012 datasets/VOC2012_to_clipart \
  -t Clipart datasets/clipart \
  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2clipart

# VOC to Comic
mkdir -p datasets/VOC2007_to_comic
cp -r datasets/VOC2007/* datasets/VOC2007_to_comic
mkdir -p datasets/VOC2012_to_comic
cp -r datasets/VOC2012/* datasets/VOC2012_to_comic
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Comic datasets/comic \
  --translated-source datasets/VOC2007_to_comic datasets/VOC2012_to_comic \
  --log logs/cyclegan_resnet9/translation/voc2comic --netG resnet_9
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_comic VOC2012Partial datasets/VOC2012_to_comic \
  -t Comic datasets/comic \
  --test VOC2007Test datasets/VOC2007 ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2comic MODEL.ROI_HEADS.NUM_CLASSES 6

# VOC to WaterColor
mkdir -p datasets/VOC2007_to_watercolor
cp -r datasets/VOC2007/* datasets/VOC2007_to_watercolor
mkdir -p datasets/VOC2012_to_watercolor
cp -r datasets/VOC2012/* datasets/VOC2012_to_watercolor
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t WaterColor datasets/watercolor \
  --translated-source datasets/VOC2007_to_watercolor datasets/VOC2012_to_watercolor \
  --log logs/cyclegan_resnet9/translation/voc2watercolor --netG resnet_9
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_watercolor VOC2012Partial datasets/VOC2012_to_watercolor \
  -t WaterColor datasets/watercolor \
  --test VOC2007Test datasets/VOC2007 WaterColorTest datasets/watercolor --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2watercolor MODEL.ROI_HEADS.NUM_CLASSES 6

# Cityscapes to Foggy Cityscapes
mkdir -p datasets/cityscapes_to_foggy_cityscapes
cp -r datasets/cityscapes_in_voc/* datasets/cityscapes_to_foggy_cityscapes
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Cityscapes datasets/cityscapes_in_voc \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --translated-source datasets/cityscapes_to_foggy_cityscapes \
  --log logs/cyclegan/translation/cityscapes2foggy
# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes datasets/cityscapes_to_foggy_cityscapes/ \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/cityscapes2foggy
# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes datasets/cityscapes_to_foggy_cityscapes/ \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/cityscapes2foggy

# Sim10k to Cityscapes Car
mkdir -p datasets/sim10k_to_cityscapes_car
cp -r datasets/sim10k/* datasets/sim10k_to_cityscapes_car
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Sim10k datasets/sim10k -t Cityscapes datasets/cityscapes_in_voc \
  --log logs/cyclegan/translation/sim10k2cityscapes_car --translated-source datasets/sim10k_to_cityscapes_car --image-base 256
# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1
# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1

# GTA5 to Cityscapes
mkdir -p datasets/gta5_to_cityscapes
cp -r datasets/synscapes_detection/* datasets/gta5_to_cityscapes
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc \
  --log logs/cyclegan/translation/gta52cityscapes --translated-source datasets/gta5_to_cityscapes --image-base 256
# ResNet101 Based Faster RCNN: GTA5 -> Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s GTA5 datasets/synscapes_detection GTA5 datasets/gta5_to_cityscapes -t Cityscapes datasets/cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/gta52cityscapes
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/README.md
================================================
# Decoupled Adaptation for Cross-Domain Object Detection
## Installation
Our code is based on
- [Detectron2 latest (v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html)
- [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models)
Please install them before use.
## Method
Compared with previous cross-domain object detection methods, D-adapt decouples the adversarial adaptation from the training of detector.
The whole pipeline is as follows:
First, you need to run ``source_only.py`` to obtain pre-trained models. (See source_only.sh for scripts.)
Then you need to run ``d_adapt.py`` to obtain adapted models. (See d_adapt.sh for scripts).
When the domain discrepancy is large, you need to run ``d_adapt.py`` multiple times.
For better readability, we implement the training of category adaptor in ``category_adaptation.py``,
implement the training of the bounding box adaptor in ``bbox_adaptation.py``,
and implement the training of the detector and connect the above components in ``d_adapt.py``.
This can facilitate you to modify and replace other adaptors.
We provide independent training arguments for detector, category adaptor and bounding box adaptor.
The arguments of the latter two end with ``-c`` and ``-b``, respectively.
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{jiang2021decoupled,
title = {Decoupled Adaptation for Cross-Domain Object Detection},
author = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long},
booktitle = {ICLR},
year = {2022}
}
```
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/bbox_adaptation.py
================================================
"""
Training a bounding box adaptor
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import os.path as osp
import argparse
from collections import deque
import tqdm
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
from detectron2.modeling.box_regression import Box2BoxTransform
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.modules.regressor import Regressor
from tllib.alignment.mdd import ImageRegressor, RegressionMarginDisparityDiscrepancy
from tllib.alignment.d_adapt.proposal import ProposalDataset, PersistentProposalList, flatten, ExpandCrop
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class BoxTransform(nn.Module):
    """Select per-class box-regression deltas and decode them into boxes.

    Wraps detectron2's ``Box2BoxTransform`` with the standard
    ``(10, 10, 5, 5)`` regression weights.
    """

    def __init__(self):
        super(BoxTransform, self).__init__()
        # standard Faster R-CNN box-regression weights
        self.box_transform = Box2BoxTransform(weights=(10.0, 10.0, 5.0, 5.0))

    def forward(self, pred_delta, gt_classes, proposal_boxes):
        """Keep each proposal's deltas for its ground-truth class and decode.

        Args:
            pred_delta: predicted offsets for every class (4 columns per class)
            gt_classes: ground-truth class index per proposal
            proposal_boxes: reference boxes the deltas are relative to

        Returns:
            (selected deltas, decoded boxes) for the ground-truth classes
        """
        # columns [4c, 4c+1, 4c+2, 4c+3] of each row hold class c's deltas
        delta_columns = 4 * gt_classes[:, None] + torch.arange(4, device=device)
        selected_delta = pred_delta.gather(dim=1, index=delta_columns)
        decoded_boxes = self.box_transform.apply_deltas(selected_delta, proposal_boxes)
        return selected_delta, decoded_boxes
def iou_between(
    boxes1: torch.Tensor,
    boxes2: torch.Tensor,
    eps: float = 1e-7,
    reduction: str = "none"
):
    """Element-wise IoU between aligned ``(x1, y1, x2, y2)`` boxes.

    ``boxes1[i]`` is compared with ``boxes2[i]``; ``reduction`` may be
    ``'none'`` (default), ``'mean'`` or ``'sum'``.
    """
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)

    assert (x2 >= x1).all(), "bad box: x1 larger than x2"
    assert (y2 >= y1).all(), "bad box: y1 larger than y2"

    # overlap extents along each axis, clamped at zero for disjoint boxes
    overlap_w = (torch.min(x2, x2g) - torch.max(x1, x1g)).clamp(min=0)
    overlap_h = (torch.min(y2, y2g) - torch.max(y1, y1g)).clamp(min=0)
    intersection = overlap_w * overlap_h
    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intersection
    iou = intersection / (union + eps)

    if reduction == 'mean':
        return iou.mean()
    if reduction == 'sum':
        return iou.sum()
    return iou
def clamp_single(box, w, h):
    """Clamp one ``(x1, y1, x2, y2)`` box to lie within a ``w`` x ``h`` image.

    Args:
        box: tensor of 4 coordinates ``(x1, y1, x2, y2)``
        w: image width -- upper bound for the x coordinates
        h: image height -- upper bound for the y coordinates

    Returns:
        A 1-D tensor of the clamped coordinates.
    """
    x1, y1, x2, y2 = box
    # torch.stack keeps the input dtype/device; torch.tensor() on a tuple of
    # 0-dim tensors copy-constructs them (UserWarning) and re-infers the dtype.
    return torch.stack((
        x1.clamp(min=0, max=w),
        y1.clamp(min=0, max=h),
        x2.clamp(min=0, max=w),
        y2.clamp(min=0, max=h),
    ))
def clamp(boxes, widths, heights):
    """Clamp each box to its own image extent and stack the results.

    ``boxes[i]`` is limited to ``widths[i]`` x ``heights[i]``.
    """
    return torch.stack(
        [clamp_single(box, w, h) for box, w, h in zip(boxes, widths, heights)],
        dim=0,
    )
class BoundingBoxAdaptor:
    def __init__(self, class_names, log, args):
        """Build the bounding box adaptor.

        Args:
            class_names: sequence of foreground category names
            log: directory where checkpoints and logs are written
            args: parsed command-line arguments; the adaptor-specific options
                carry a ``_b`` suffix
        """
        self.class_names = class_names
        # Copy every "_b"-suffixed option to its plain name so the rest of the
        # class can read e.g. args.lr instead of args.lr_b.
        # NOTE(review): str.replace removes every "_b" occurrence, not just a
        # trailing suffix -- relies on no option name containing "_b" elsewhere.
        for k, v in args._get_kwargs():
            setattr(args, k.replace("_b", ""), v)
        self.args = args
        print(self.args)
        self.logger = CompleteLogger(log)

        # create model
        print("=> using pre-trained model '{}'".format(args.arch))
        backbone = utils.get_model(args.arch, pretrain=not args.scratch)
        num_classes = len(class_names)
        bottleneck_dim = args.bottleneck_dim
        # shared conv bottleneck on top of the backbone feature map
        bottleneck = nn.Sequential(
            nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(bottleneck_dim),
            nn.ReLU(),
        )
        # main regression head: 4 box deltas per class
        head = nn.Sequential(
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(bottleneck_dim, num_classes * 4),
        )
        for layer in head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        # adversarial head with the same architecture, used by the MDD loss
        adv_head = nn.Sequential(
            nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(bottleneck_dim),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(bottleneck_dim, num_classes * 4),
        )
        for layer in adv_head:
            if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                nn.init.constant_(layer.bias, 0)
        self.model = ImageRegressor(
            backbone, num_classes * 4, bottleneck=bottleneck,
            head=head, adv_head=adv_head
        ).to(device)
        self.box_transform = BoxTransform()
def load_checkpoint(self, path=None):
if path is None:
path = self.logger.get_checkpoint_path('latest')
if osp.exists(path):
checkpoint = torch.load(path, map_location='cpu')
self.model.load_state_dict(checkpoint)
return True
else:
return False
def prepare_training_data(self, proposal_list: PersistentProposalList, labeled=True):
if not labeled:
# remove (predicted) background proposals
filtered_proposals_list = []
for proposals in proposal_list:
keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names))
filtered_proposals_list.append(proposals[keep_indices])
else:
# remove proposals with low IoU
filtered_proposals_list = []
for proposals in proposal_list:
keep_indices = proposals.gt_ious > 0.3
filtered_proposals_list.append(proposals[keep_indices])
filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train)
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = T.Compose([
T.Resize((self.args.resize_size, self.args.resize_size)),
# T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
# T.RandomGrayscale(),
T.ToTensor(),
normalize
])
dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand))
dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
shuffle=True, num_workers=self.args.workers, drop_last=True)
return dataloader
def prepare_validation_data(self, proposal_list: PersistentProposalList):
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = T.Compose([
T.Resize((self.args.resize_size, self.args.resize_size)),
T.ToTensor(),
normalize
])
# remove (predicted) background proposals
filtered_proposals_list = []
for proposals in proposal_list:
# keep_indices = (0 <= proposals.gt_classes) & (proposals.gt_classes < len(self.class_names))
keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names))
filtered_proposals_list.append(proposals[keep_indices])
filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_val)
dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand))
dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
shuffle=False, num_workers=self.args.workers, drop_last=False)
return dataloader
def prepare_test_data(self, proposal_list: PersistentProposalList):
normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
transform = T.Compose([
T.Resize((self.args.resize_size, self.args.resize_size)),
T.ToTensor(),
normalize
])
dataset = ProposalDataset(proposal_list, transform, crop_func=ExpandCrop(self.args.expand))
dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
shuffle=False, num_workers=self.args.workers, drop_last=False)
return dataloader
def predict(self, data_loader):
# switch to evaluate mode
self.model.eval()
predictions = deque()
with torch.no_grad():
for images, labels in tqdm.tqdm(data_loader):
images = images.to(device)
pred_classes = labels['pred_classes'].to(device)
pred_boxes = labels['pred_boxes'].to(device).float()
# compute output
pred_deltas = self.model(images)
_, pred_boxes = self.box_transform(pred_deltas, pred_classes, pred_boxes)
pred_boxes = clamp(pred_boxes.cpu(), labels['width'], labels['height'])
pred_boxes = pred_boxes.numpy().tolist()
for p in pred_boxes:
predictions.append(p)
return predictions
def validate_baseline(self, val_loader):
"""call this function if you have labeled data for validation"""
ious = AverageMeter("IoU", ":.4e")
print("Calculate baseline IoU:")
for _, labels in tqdm.tqdm(val_loader):
gt_boxes = labels['gt_boxes']
pred_boxes = labels['pred_boxes']
ious.update(iou_between(pred_boxes, gt_boxes).mean().item(), gt_boxes.size(0))
print(' * Baseline IoU {:.3f}'.format(ious.avg))
return ious.avg
    @staticmethod
    def validate(val_loader, model, box_transform, args) -> float:
        """Evaluate ``model`` on labeled proposals; return the mean IoU
        between its refined boxes and the ground-truth boxes.

        Call this function only when labeled validation data is available.
        """
        batch_time = AverageMeter('Time', ':6.3f')
        ious = AverageMeter("IoU", ":.4e")
        progress = ProgressMeter(
            len(val_loader),
            [batch_time, ious],
            prefix='Test: ')

        # switch to evaluate mode
        model.eval()

        with torch.no_grad():
            end = time.time()
            for i, (images, labels) in enumerate(val_loader):
                images = images.to(device)
                pred_classes = labels['pred_classes'].to(device)
                gt_boxes = labels['gt_boxes'].to(device).float()
                pred_boxes = labels['pred_boxes'].to(device).float()

                # compute output: decode predicted deltas and clamp the
                # refined boxes to the image bounds
                pred_deltas = model(images)
                _, pred_boxes = box_transform(pred_deltas, pred_classes, pred_boxes)
                pred_boxes = clamp(pred_boxes.cpu(), labels['width'], labels['height'])
                ious.update(iou_between(pred_boxes, gt_boxes.cpu()).mean().item(), images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i)

        print(' * IoU {:.3f}'.format(ious.avg))
        return ious.avg
    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):
        """Train the bounding box adaptor.

        Stage 1 pre-trains a plain source-only regressor with smooth-L1 on
        source box deltas; stage 2 trains ``self.model`` on both domains with
        margin disparity discrepancy (MDD). When no labels exist on the
        target domain, set ``data_loader_validation=None``.
        """
        args = self.args
        print(args)
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')

        cudnn.benchmark = True

        iter_source = ForeverDataIterator(data_loader_source)
        iter_target = ForeverDataIterator(data_loader_target)
        best_iou = 0.
        box_transform = self.box_transform

        # first pre-train on the source domain
        model = Regressor(
            self.model.backbone, len(self.class_names) * 4,
            bottleneck=nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            ),
            head=nn.Linear(self.model.backbone.out_features, len(self.class_names) * 4),
            bottleneck_dim=self.model.backbone.out_features
        ).to(device)
        optimizer = Adam(model.get_parameters(), args.pretrain_lr, weight_decay=args.pretrain_weight_decay)
        # inverse-decay schedule; NOTE(review): the lambda already multiplies by
        # args.pretrain_lr (tllib convention) -- confirm against get_parameters()
        lr_scheduler = LambdaLR(optimizer, lambda x: args.pretrain_lr * (1. + args.pretrain_lr_gamma * float(x)) ** (-args.pretrain_lr_decay))

        for epoch in range(args.pretrain_epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])

            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            ious = AverageMeter("IoU", ":.4e")
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, ious],
                prefix="Epoch: [{}]".format(epoch))

            # switch to train mode
            model.train()

            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_s = x_s.to(device)
                # bounding box offsets: regression targets from proposal -> gt
                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'], labels_s['gt_boxes']).to(device).float()
                pred_boxes_s = labels_s['pred_boxes'].to(device).float()
                gt_classes_s = labels_s['gt_fg_classes'].to(device)
                gt_boxes_s = labels_s['gt_boxes'].to(device).float()

                # measure data loading time
                data_time.update(time.time() - end)

                # compute output
                pred_delta_s, _ = model(x_s)
                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)
                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)
                loss = reg_loss

                losses.update(loss.item(), x_s.size(0))
                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i)

            # evaluate on validation set
            if data_loader_validation is not None:
                iou = self.validate(data_loader_validation, model, box_transform, args)
                best_iou = max(iou, best_iou)

        # training on both domains
        model = self.model
        optimizer = SGD(model.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

        for epoch in range(args.epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])
            # train for one epoch
            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            ious = AverageMeter("IoU", ":.4e")
            ious_t = AverageMeter("IoU (t)", ":.4e")
            ious_s_adv = AverageMeter("IoU (s, adv)", ":.4e")
            ious_t_adv = AverageMeter("IoU (t, adv)", ":.4e")
            trans_losses = AverageMeter('Trans Loss', ':3.2f')
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, trans_losses, ious, ious_t, ious_s_adv, ious_t_adv],
                prefix="Epoch: [{}]".format(epoch))

            # switch to train mode
            model.train()
            mdd = RegressionMarginDisparityDiscrepancy(args.margin).to(device)

            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_t, labels_t = next(iter_target)
                x_s = x_s.to(device)
                x_t = x_t.to(device)
                # bounding box offsets
                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'], labels_s['gt_boxes']).to(device).float()
                pred_boxes_s = labels_s['pred_boxes'].to(device).float()
                gt_classes_s = labels_s['gt_fg_classes'].to(device)
                gt_boxes_s = labels_s['gt_boxes'].to(device).float()
                pred_boxes_t = labels_t['pred_boxes'].to(device).float()
                # target domain has no labels: the detector's predicted classes
                # serve as pseudo ground-truth classes
                gt_classes_t = labels_t['pred_classes'].to(device)
                gt_boxes_t = labels_t['gt_boxes'].to(device).float()

                # measure data loading time
                data_time.update(time.time() - end)

                # compute output on the concatenated source+target batch
                x = torch.cat([x_s, x_t], dim=0)
                outputs, outputs_adv = model(x)
                pred_delta_s, pred_delta_t = outputs.chunk(2, dim=0)
                pred_delta_s_adv, pred_delta_t_adv = outputs_adv.chunk(2, dim=0)
                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)
                pred_delta_t, pred_boxes_t = box_transform(pred_delta_t, gt_classes_t, pred_boxes_t)
                pred_delta_s_adv, pred_boxes_s_adv = box_transform(pred_delta_s_adv, gt_classes_s, pred_boxes_s)
                pred_delta_t_adv, pred_boxes_t_adv = box_transform(pred_delta_t_adv, gt_classes_t, pred_boxes_t)
                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)
                # compute margin disparity discrepancy between domains
                transfer_loss = mdd(pred_delta_s, pred_delta_s_adv, pred_delta_t, pred_delta_t_adv)
                # for adversarial classifier, minimize negative mdd is equal to maximize mdd
                loss = reg_loss - transfer_loss * args.trade_off
                # NOTE(review): presumably advances the model's internal
                # schedule (e.g. the gradient-reverse-layer coefficient) before
                # the backward pass -- confirm in tllib's ImageRegressor
                model.step()

                losses.update(loss.item(), x_s.size(0))
                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))
                ious_t.update(iou_between(pred_boxes_t.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))
                ious_s_adv.update(iou_between(pred_boxes_s_adv.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))
                ious_t_adv.update(iou_between(pred_boxes_t_adv.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))
                trans_losses.update(transfer_loss.item(), x_s.size(0))

                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i)

            # evaluate on validation set
            if data_loader_validation is not None:
                iou = self.validate(data_loader_validation, model, box_transform, args)
                best_iou = max(iou, best_iou)

        # save checkpoint
        torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))
        print("best_iou = {:3.1f}".format(best_iou))

        self.logger.logger.flush()
@staticmethod
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(add_help=False)
# dataset parameters
parser.add_argument('--resize-size-b', type=int, default=224,
help='the image size after resizing')
parser.add_argument('--max-train-b', type=int, default=10)
parser.add_argument('--max-val-b', type=int, default=10)
parser.add_argument('--expand-b', type=float, default=2.,
help='The expanding ratio between the input of the bounding box adaptor'
'(the crops of objects) and the the original predicted box.')
# model parameters
parser.add_argument('--arch-b', metavar='ARCH', default='resnet101',
choices=utils.get_model_names(),
help='backbone architecture: ' +
' | '.join(utils.get_model_names()) +
' (default: resnet101)')
parser.add_argument('--bottleneck-dim-b', default=1024, type=int,
help='Dimension of bottleneck')
parser.add_argument('--no-pool-b', action='store_true',
help='no pool layer after the feature extractor.')
parser.add_argument('--scratch-b', action='store_true', help='whether train from scratch.')
parser.add_argument('--margin', type=float, default=4., help="margin hyper-parameter")
parser.add_argument('--trade-off', default=0.1, type=float,
help='the trade-off hyper-parameter for transfer loss')
# training parameters
parser.add_argument('--batch-size-b', default=32, type=int,
metavar='N',
help='mini-batch size (default: 64)')
parser.add_argument('--lr-b', default=0.004, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler')
parser.add_argument('--lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler')
parser.add_argument('--weight-decay-b', default=5e-4, type=float,
metavar='W', help='weight decay (default: 5e-4)')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--workers-b', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('--epochs-b', default=2, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--pretrain-lr-b', default=0.001, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--pretrain-lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler')
parser.add_argument('--pretrain-lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler')
parser.add_argument('--pretrain-weight-decay-b', default=1e-3, type=float,
metavar='W', help='weight decay (default: 1e-3)')
parser.add_argument('--pretrain-epochs-b', default=10, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--iters-per-epoch-b', default=1000, type=int,
help='Number of iterations per epoch')
parser.add_argument('--print-freq-b', default=100, type=int,
metavar='N', help='print frequency (default: 100)')
parser.add_argument('--seed-b', default=None, type=int,
help='seed for initializing training. ')
parser.add_argument("--log-b", type=str, default='box',
help="Where to save logs, checkpoints and debugging images.")
return parser
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/category_adaptation.py
================================================
"""
Training a category adaptor
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import os.path as osp
from collections import deque
import tqdm
from typing import List
import torch
from torch import Tensor
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
sys.path.append('../../../..')
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier
from tllib.alignment.d_adapt.proposal import ProposalDataset, flatten, Proposal
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.vision.transforms import ResizeImage
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ConfidenceBasedDataSelector:
    """Select data points whose confidence score is in the top fraction of
    their category.

    Args:
        confidence_ratio (float): fraction in [0, 1] of the highest-scoring
            samples per category that count as confidently predicted.
        category_names (iterable): keys identifying every possible category.

    Usage: accumulate statistics with :meth:`extend`, call :meth:`calculate`
    once, then query with :meth:`whether_select`.
    """
    def __init__(self, confidence_ratio=0.1, category_names=()):
        self.confidence_ratio = confidence_ratio
        self.categories = []  # category key of every collected sample
        self.scores = []  # confidence score of every collected sample
        self.category_names = category_names
        self.per_category_thresholds = None  # filled in by calculate()
    def extend(self, categories, scores):
        """Accumulate a batch of per-sample (category, score) statistics."""
        self.categories.extend(categories)
        self.scores.extend(scores)
    def calculate(self):
        """Compute a per-category confidence threshold from collected scores."""
        per_category_scores = {c: [] for c in self.category_names}
        for c, s in zip(self.categories, self.scores):
            per_category_scores[c].append(s)
        per_category_thresholds = {}
        print(per_category_scores.keys())
        for c, s in per_category_scores.items():
            # sort descending so index k selects the k-th highest score
            s.sort(reverse=True)
            print(c, len(s), int(self.confidence_ratio * len(s)))
            if len(s):
                # clamp the index so confidence_ratio == 1.0 (index == len(s))
                # no longer raises IndexError; it now selects the lowest score
                idx = min(int(self.confidence_ratio * len(s)), len(s) - 1)
                per_category_thresholds[c] = s[idx]
            else:
                # no samples observed for this category: threshold 1. means
                # whether_select never selects it (scores are <= 1)
                per_category_thresholds[c] = 1.
        print('----------------------------------------------------')
        print("confidence threshold for each category:")
        for c in self.category_names:
            print('\t', c, round(per_category_thresholds[c], 3))
        print('----------------------------------------------------')
        self.per_category_thresholds = per_category_thresholds
    def whether_select(self, categories, scores):
        """Return one bool per sample: True iff its score is strictly above
        its category's threshold."""
        assert self.per_category_thresholds is not None, "please call calculate before selection!"
        return [s > self.per_category_thresholds[c] for c, s in zip(categories, scores)]
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy that's robust to label noise.

    The logits are shifted by ``offset`` and capped at 1 before computing the
    standard cross-entropy; the summed loss is then averaged over the batch.
    Accepts the same arguments as :class:`torch.nn.CrossEntropyLoss` plus the
    keyword-only ``offset``.
    """
    def __init__(self, *args, offset=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.offset = offset
    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # shift and cap the logits, then average the summed loss per sample
        capped = (input + self.offset).clamp(max=1.)
        total = F.cross_entropy(capped, target, weight=self.weight,
                                ignore_index=self.ignore_index, reduction='sum')
        return total / input.shape[0]
class CategoryAdaptor:
    """Classifies object proposals into ``len(class_names) + 1`` categories
    (the extra class is background) and aligns source/target domains with a
    conditional adversarial loss (CDAN).

    Args:
        class_names: names of the foreground classes.
        log: directory passed to :class:`CompleteLogger` for logs/checkpoints.
        args: namespace produced by :meth:`get_parser` (options suffixed ``-c``).
    """
    def __init__(self, class_names, log, args):
        self.class_names = class_names
        # mirror every "xxx_c" option as "xxx" so the rest of the class can
        # use the plain names.
        # NOTE(review): rstrip("_c") strips ANY trailing '_' and 'c'
        # characters, not just the literal "_c" suffix — it happens to work
        # for the current option names, but str.removesuffix("_c") would be
        # the safe form; verify before adding options ending in 'c'.
        for k, v in args._get_kwargs():
            setattr(args, k.rstrip("_c"), v)
        self.args = args
        print(self.args)
        self.logger = CompleteLogger(log)
        # selector over all categories including the background class
        self.selector = ConfidenceBasedDataSelector(self.args.confidence_ratio, range(len(self.class_names) + 1))
        # create model
        print("=> using model '{}'".format(args.arch))
        backbone = utils.get_model(args.arch, pretrain=not args.scratch)
        pool_layer = nn.Identity() if args.no_pool else None
        num_classes = len(self.class_names) + 1
        self.model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                     pool_layer=pool_layer, finetune=not args.scratch).to(device)
    def load_checkpoint(self):
        """Load the latest checkpoint if it exists; return True on success."""
        if osp.exists(self.logger.get_checkpoint_path('latest')):
            checkpoint = torch.load(self.logger.get_checkpoint_path('latest'), map_location='cpu')
            self.model.load_state_dict(checkpoint)
            return True
        else:
            return False
    def prepare_training_data(self, proposal_list: List[Proposal], labeled=True):
        """Build a shuffled training DataLoader from proposals.

        When ``labeled`` is False (target domain), proposals whose predicted
        score falls inside ``args.ignored_scores`` are dropped and the
        confidence selector's thresholds are computed from the survivors.
        When True (source domain), proposals with ignored classes or with an
        IoU inside ``args.ignored_ious`` are dropped instead.
        """
        if not labeled:
            # remove proposals with confidence score between (ignored_scores[0], ignored_scores[1])
            filtered_proposals_list = []
            assert len(self.args.ignored_scores) == 2 and self.args.ignored_scores[0] <= self.args.ignored_scores[1], \
                "Please provide a range for ignored_scores!"
            for proposals in proposal_list:
                keep_indices = ~((self.args.ignored_scores[0] < proposals.pred_scores)
                                 & (proposals.pred_scores < self.args.ignored_scores[1]))
                filtered_proposals_list.append(proposals[keep_indices])
            # calculate confidence threshold for each category on the target domain
            for proposals in filtered_proposals_list:
                self.selector.extend(proposals.pred_classes.tolist(), proposals.pred_scores.tolist())
            self.selector.calculate()
        else:
            # remove proposals with ignored classes or ious between (ignored_ious[0], ignored_ious[1])
            filtered_proposals_list = []
            for proposals in proposal_list:
                keep_indices = (proposals.gt_classes != -1) & \
                               ~((self.args.ignored_ious[0] < proposals.gt_ious) &
                                 (proposals.gt_ious < self.args.ignored_ious[1]))
                filtered_proposals_list.append(proposals[keep_indices])
        # cap the number of proposals kept per image at args.max_train
        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train)
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # heavy augmentation for training crops
        transform = T.Compose([
            ResizeImage(self.args.resize_size),
            T.RandomHorizontalFlip(),
            T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5),
            T.RandomGrayscale(),
            T.ToTensor(),
            normalize
        ])
        dataset = ProposalDataset(filtered_proposals_list, transform)
        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
                                shuffle=True, num_workers=self.args.workers, drop_last=True)
        return dataloader
    def prepare_validation_data(self, proposal_list: List[Proposal]):
        """call this function if you have labeled data for validation"""
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # no augmentation at validation time
        transform = T.Compose([
            ResizeImage(self.args.resize_size),
            T.ToTensor(),
            normalize
        ])
        # remove proposals with ignored classes
        filtered_proposals_list = []
        for proposals in proposal_list:
            keep_indices = proposals.gt_classes != -1
            filtered_proposals_list.append(proposals[keep_indices])
        filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_val)
        dataset = ProposalDataset(filtered_proposals_list, transform)
        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
                                shuffle=False, num_workers=self.args.workers, drop_last=False)
        return dataloader
    def prepare_test_data(self, proposal_list: List[Proposal]):
        """Build an unshuffled DataLoader over all proposals (no filtering)."""
        normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        transform = T.Compose([
            ResizeImage(self.args.resize_size),
            T.ToTensor(),
            normalize
        ])
        dataset = ProposalDataset(proposal_list, transform)
        dataloader = DataLoader(dataset, batch_size=self.args.batch_size,
                                shuffle=False, num_workers=self.args.workers, drop_last=False)
        return dataloader
    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):
        """When no labels exists on target domain, please set data_loader_validation=None"""
        args = self.args
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
        cudnn.benchmark = True
        iter_source = ForeverDataIterator(data_loader_source)
        iter_target = ForeverDataIterator(data_loader_target)
        model = self.model
        feature_dim = model.features_dim
        num_classes = len(self.class_names) + 1
        # discriminator input dim depends on whether the multilinear map is randomized
        if args.randomized:
            domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)
        else:
            domain_discri = DomainDiscriminator(feature_dim * num_classes, hidden_size=1024).to(device)
        all_parameters = model.get_parameters() + domain_discri.get_parameters()
        # define optimizer and lr scheduler
        optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
        # inverse-decay schedule: lr * (1 + gamma * step) ** (-decay)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
        # define loss function
        domain_adv = ConditionalDomainAdversarialLoss(
            domain_discri, entropy_conditioning=args.entropy,
            num_classes=num_classes, features_dim=feature_dim, randomized=args.randomized,
            randomized_dim=args.randomized_dim
        ).to(device)
        # start training
        best_acc1 = 0.
        for epoch in range(args.epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])
            # train for one epoch
            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            losses_t = AverageMeter('Loss(t)', ':3.2f')
            trans_losses = AverageMeter('Trans Loss', ':3.2f')
            cls_accs = AverageMeter('Cls Acc', ':3.1f')
            domain_accs = AverageMeter('Domain Acc', ':3.1f')
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, losses_t, trans_losses, cls_accs, domain_accs],
                prefix="Epoch: [{}]".format(epoch))
            # switch to train mode
            model.train()
            domain_adv.train()
            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_t, labels_t = next(iter_target)
                # assign pseudo labels for target-domain proposals with extremely high confidence
                selected = torch.tensor(
                    self.selector.whether_select(
                        labels_t['pred_classes'].numpy().tolist(),
                        labels_t['pred_scores'].numpy().tolist()
                    )
                )
                # unselected samples get label -1, which cross-entropy ignores below
                pseudo_classes_t = selected * labels_t['pred_classes'] + (~selected) * -1
                pseudo_classes_t = pseudo_classes_t.to(device)
                x_s = x_s.to(device)
                x_t = x_t.to(device)
                gt_classes_s = labels_s['gt_classes'].to(device)
                # measure data loading time
                data_time.update(time.time() - end)
                # compute output: source and target batches forwarded together
                x = torch.cat((x_s, x_t), dim=0)
                y, f = model(x)
                y_s, y_t = y.chunk(2, dim=0)
                f_s, f_t = f.chunk(2, dim=0)
                cls_loss = F.cross_entropy(y_s, gt_classes_s, ignore_index=-1)
                # noisy pseudo labels -> robust loss on the target domain
                cls_loss_t = RobustCrossEntropyLoss(ignore_index=-1, offset=args.epsilon)(y_t, pseudo_classes_t)
                transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
                domain_acc = domain_adv.domain_discriminator_accuracy
                loss = cls_loss + transfer_loss * args.trade_off + cls_loss_t
                cls_acc = accuracy(y_s, gt_classes_s)[0]
                losses.update(loss.item(), x_s.size(0))
                cls_accs.update(cls_acc, x_s.size(0))
                domain_accs.update(domain_acc, x_s.size(0))
                trans_losses.update(transfer_loss.item(), x_s.size(0))
                losses_t.update(cls_loss_t.item(), x_s.size(0))
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % args.print_freq == 0:
                    progress.display(i)
            # evaluate on validation set
            if data_loader_validation is not None:
                acc1 = self.validate(data_loader_validation, model, self.class_names, args)
                best_acc1 = max(acc1, best_acc1)
            # save checkpoint
            torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))
        print("best_acc1 = {:3.1f}".format(best_acc1))
        # free GPU memory held by the adversarial head once training is done
        domain_adv.to(torch.device("cpu"))
        self.logger.logger.flush()
    def predict(self, data_loader):
        """Return a deque of predicted class indices, one per proposal, in
        data-loader order."""
        # switch to evaluate mode
        self.model.eval()
        predictions = deque()
        with torch.no_grad():
            for images, _ in tqdm.tqdm(data_loader):
                images = images.to(device)
                # compute output
                output = self.model(images)
                prediction = output.argmax(-1).cpu().numpy().tolist()
                for p in prediction:
                    predictions.append(p)
        return predictions
    @staticmethod
    def validate(val_loader, model, class_names, args) -> float:
        """Evaluate ``model`` on ``val_loader``; print a confusion matrix and
        return top-1 accuracy."""
        batch_time = AverageMeter('Time', ':6.3f')
        losses = AverageMeter('Loss', ':.4e')
        top1 = AverageMeter('Acc@1', ':6.2f')
        progress = ProgressMeter(
            len(val_loader),
            [batch_time, losses, top1],
            prefix='Test: ')
        # switch to evaluate mode
        model.eval()
        # +1 for the background class
        confmat = ConfusionMatrix(len(class_names)+1)
        with torch.no_grad():
            end = time.time()
            for i, (images, labels) in enumerate(val_loader):
                images = images.to(device)
                gt_classes = labels['gt_classes'].to(device)
                # compute output
                output = model(images)
                loss = F.cross_entropy(output, gt_classes)
                # measure accuracy and record loss
                acc1, = accuracy(output, gt_classes, topk=(1,))
                confmat.update(gt_classes, output.argmax(1))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1.item(), images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % args.print_freq == 0:
                    progress.display(i)
            print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
            print(confmat.format(class_names+["bg"]))
        return top1.avg
    @staticmethod
    def get_parser() -> argparse.ArgumentParser:
        """Build the argument parser for the category adaptor.

        Every option carries a ``-c`` suffix so it can be parsed from the same
        command line as the bounding box adaptor's ``-b`` options.
        """
        parser = argparse.ArgumentParser(add_help=False)
        # dataset parameters
        parser.add_argument('--resize-size-c', type=int, default=112,
                            help='the image size after resizing')
        parser.add_argument('--ignored-scores-c', type=float, nargs='+', default=[0.05, 0.3])
        parser.add_argument('--max-train-c', type=int, default=10)
        parser.add_argument('--max-val-c', type=int, default=2)
        parser.add_argument('--ignored-ious-c', type=float, nargs='+', default=(0.4, 0.5),
                            help='the iou threshold for ignored boxes')
        # model parameters
        parser.add_argument('--arch-c', metavar='ARCH', default='resnet101',
                            choices=utils.get_model_names(),
                            help='backbone architecture: ' +
                                 ' | '.join(utils.get_model_names()) +
                                 ' (default: resnet101)')
        parser.add_argument('--bottleneck-dim-c', default=1024, type=int,
                            help='Dimension of bottleneck')
        parser.add_argument('--no-pool-c', action='store_true',
                            help='no pool layer after the feature extractor.')
        parser.add_argument('--scratch-c', action='store_true', help='whether train from scratch.')
        parser.add_argument('--randomized-c', action='store_true',
                            help='using randomized multi-linear-map (default: False)')
        parser.add_argument('--randomized-dim-c', default=1024, type=int,
                            help='randomized dimension when using randomized multi-linear-map (default: 1024)')
        parser.add_argument('--entropy-c', default=False, action='store_true', help='use entropy conditioning')
        parser.add_argument('--trade-off-c', default=1., type=float,
                            help='the trade-off hyper-parameter for transfer loss')
        parser.add_argument('--confidence-ratio-c', default=0.0, type=float)
        parser.add_argument('--epsilon-c', default=0.01, type=float,
                            help='epsilon hyper-parameter in Robust Cross Entropy')
        # training parameters
        parser.add_argument('--batch-size-c', default=64, type=int,
                            metavar='N',
                            help='mini-batch size (default: 64)')
        # note the explicit dest: this option is stored as args.lr directly
        parser.add_argument('--learning-rate-c', default=0.01, type=float,
                            metavar='LR', help='initial learning rate', dest='lr')
        parser.add_argument('--lr-gamma-c', default=0.001, type=float, help='parameter for lr scheduler')
        parser.add_argument('--lr-decay-c', default=0.75, type=float, help='parameter for lr scheduler')
        parser.add_argument('--momentum-c', default=0.9, type=float, metavar='M', help='momentum')
        parser.add_argument('--weight-decay-c', default=1e-3, type=float,
                            metavar='W', help='weight decay (default: 1e-3)',
                            dest='weight_decay')
        parser.add_argument('--workers-c', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 2)')
        parser.add_argument('--epochs-c', default=10, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--iters-per-epoch-c', default=1000, type=int,
                            help='Number of iterations per epoch')
        parser.add_argument('--print-freq-c', default=100, type=int,
                            metavar='N', help='print frequency (default: 100)')
        parser.add_argument('--seed-c', default=None, type=int,
                            help='seed for initializing training. ')
        parser.add_argument("--log-c", type=str, default='cdan',
                            help="Where to save logs, checkpoints and debugging images.")
        return parser
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_cityscapes.yaml
================================================
MODEL:
META_ARCHITECTURE: "DecoupledGeneralizedRCNN"
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
MASK_ON: False
RESNETS:
DEPTH: 101
ROI_HEADS:
NAME: "DecoupledRes5ROIHeads"
NUM_CLASSES: 8
BATCH_SIZE_PER_IMAGE: 512
ANCHOR_GENERATOR:
SIZES: [ [ 64, 128, 256, 512 ] ]
RPN:
PRE_NMS_TOPK_TEST: 6000
POST_NMS_TOPK_TEST: 1000
BATCH_SIZE_PER_IMAGE: 256
PROPOSAL_GENERATOR:
NAME: "TLRPN"
INPUT:
MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,)
MIN_SIZE_TEST: 800
MAX_SIZE_TRAIN: 1166
DATASETS:
TRAIN: ("cityscapes_trainval",)
TEST: ("cityscapes_test",)
SOLVER:
STEPS: (3999, )
MAX_ITER: 4000 # 4 epochs
WARMUP_ITERS: 100
CHECKPOINT_PERIOD: 1000
IMS_PER_BATCH: 2
BASE_LR: 0.005
LR_SCHEDULER_NAME: "ExponentialLR"
GAMMA: 0.1
TEST:
EVAL_PERIOD: 500
VIS_PERIOD: 20
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_R_101_C4_voc.yaml
================================================
MODEL:
META_ARCHITECTURE: "DecoupledGeneralizedRCNN"
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
MASK_ON: False
RESNETS:
DEPTH: 101
ROI_HEADS:
NAME: "DecoupledRes5ROIHeads"
NUM_CLASSES: 20
BATCH_SIZE_PER_IMAGE: 256
ANCHOR_GENERATOR:
SIZES: [ [ 64, 128, 256, 512 ] ]
RPN:
PRE_NMS_TOPK_TEST: 6000
POST_NMS_TOPK_TEST: 1000
BATCH_SIZE_PER_IMAGE: 128
PROPOSAL_GENERATOR:
NAME: "TLRPN"
INPUT:
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,)
MIN_SIZE_TEST: 608
MAX_SIZE_TRAIN: 1166
DATASETS:
TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
TEST: ('voc_2007_test',)
SOLVER:
STEPS: (3999, )
MAX_ITER: 4000 # 16 epochs
WARMUP_ITERS: 100
CHECKPOINT_PERIOD: 1000
IMS_PER_BATCH: 4
BASE_LR: 0.00025
LR_SCHEDULER_NAME: "ExponentialLR"
GAMMA: 0.1
TEST:
EVAL_PERIOD: 500
VIS_PERIOD: 20
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_vgg_16_cityscapes.yaml
================================================
MODEL:
META_ARCHITECTURE: "DecoupledGeneralizedRCNN"
WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth'
PIXEL_MEAN: [123.675, 116.280, 103.530]
PIXEL_STD: [58.395, 57.120, 57.375]
MASK_ON: False
BACKBONE:
NAME: "build_vgg_fpn_backbone"
ROI_HEADS:
IN_FEATURES: ["p3", "p4", "p5", "p6"]
NAME: "DecoupledStandardROIHeads"
NUM_CLASSES: 8
ROI_BOX_HEAD:
NAME: "FastRCNNConvFCHead"
NUM_FC: 2
POOLER_RESOLUTION: 7
ANCHOR_GENERATOR:
SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map
    ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ]  # Three aspect ratios (same for all in feature maps)
RPN:
IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"]
PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
PRE_NMS_TOPK_TEST: 1000 # Per FPN level
POST_NMS_TOPK_TRAIN: 1000
POST_NMS_TOPK_TEST: 1000
PROPOSAL_GENERATOR:
NAME: "TLRPN"
INPUT:
FORMAT: "RGB"
MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800)
MIN_SIZE_TEST: 800
MAX_SIZE_TEST: 1280
MAX_SIZE_TRAIN: 1280
DATASETS:
TRAIN: ("cityscapes_trainval",)
TEST: ("cityscapes_test",)
SOLVER:
STEPS: (3999, )
MAX_ITER: 4000 # 4 epochs
WARMUP_ITERS: 100
CHECKPOINT_PERIOD: 1000
IMS_PER_BATCH: 8
BASE_LR: 0.01
LR_SCHEDULER_NAME: "ExponentialLR"
GAMMA: 0.1
TEST:
EVAL_PERIOD: 500
VIS_PERIOD: 20
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/config/retinanet_R_101_FPN_voc.yaml
================================================
MODEL:
META_ARCHITECTURE: "DecoupledRetinaNet"
WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl"
BACKBONE:
NAME: "build_retinanet_resnet_fpn_backbone"
MASK_ON: False
RESNETS:
DEPTH: 101
OUT_FEATURES: [ "res4", "res5" ]
ANCHOR_GENERATOR:
SIZES: !!python/object/apply:eval [ "[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]" ]
RETINANET:
NUM_CLASSES: 20
IN_FEATURES: ["p4", "p5", "p6", "p7"]
FPN:
IN_FEATURES: ["res4", "res5"]
INPUT:
MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, )
MIN_SIZE_TEST: 608
MAX_SIZE_TRAIN: 1166
DATASETS:
TRAIN: ('voc_2007_trainval', 'voc_2012_trainval')
TEST: ('voc_2007_test',)
SOLVER:
STEPS: (3999, )
MAX_ITER: 4000 # 16 epochs
WARMUP_ITERS: 100
CHECKPOINT_PERIOD: 1000
IMS_PER_BATCH: 8
BASE_LR: 0.001
TEST:
EVAL_PERIOD: 500
VIS_PERIOD: 20
VERSION: 2
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/d_adapt.py
================================================
"""
`D-adapt: Decoupled Adaptation for Cross-Domain Object Detection `_.
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import logging
import os
import argparse
import sys
import pprint
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel
from detectron2.engine import default_writers, launch
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
import detectron2.utils.comm as comm
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
from detectron2.data import (
build_detection_train_loader,
build_detection_test_loader,
MetadataCatalog
)
from detectron2.utils.events import EventStorage
from detectron2.evaluation import inference_on_dataset
sys.path.append('../../../..')
import tllib.alignment.d_adapt.modeling.meta_arch as models
from tllib.alignment.d_adapt.proposal import ProposalGenerator, ProposalMapper, PersistentProposalList, flatten
from tllib.alignment.d_adapt.feedback import get_detection_dataset_dicts, DatasetMapper
sys.path.append('..')
import utils
import category_adaptation
import bbox_adaptation
def generate_proposals(model, num_classes, dataset_names, cache_root, cfg):
    """Generate foreground proposals and background proposals from `model` and save them to the disk"""
    # cache file names are keyed by the first dataset name only
    fg_proposals_list = PersistentProposalList(os.path.join(cache_root, "{}_fg.json".format(dataset_names[0])))
    bg_proposals_list = PersistentProposalList(os.path.join(cache_root, "{}_bg.json".format(dataset_names[0])))
    # NOTE(review): `and` short-circuits here — if the fg cache loads but the
    # bg cache does not, both lists are regenerated below and the already
    # loaded fg entries could be duplicated by extend(); confirm
    # PersistentProposalList.load() semantics before relying on partial caches.
    if not (fg_proposals_list.load() and bg_proposals_list.load()):
        for dataset_name in dataset_names:
            data_loader = build_detection_test_loader(cfg, dataset_name, mapper=ProposalMapper(cfg, False))
            generator = ProposalGenerator(num_classes=num_classes)
            # run full-dataset inference, splitting results into fg/bg proposals
            fg_proposals_list_data, bg_proposals_list_data = inference_on_dataset(model, data_loader, generator)
            fg_proposals_list.extend(fg_proposals_list_data)
            bg_proposals_list.extend(bg_proposals_list_data)
        # persist to disk so later runs can skip inference entirely
        fg_proposals_list.flush()
        bg_proposals_list.flush()
    return fg_proposals_list, bg_proposals_list
def generate_category_labels(prop, category_adaptor, cache_filename):
    """Assign predicted category labels to every proposal in ``prop``.

    Results are cached at ``cache_filename``; if the cache loads successfully
    the adaptor is not run at all.
    """
    labeled = PersistentProposalList(cache_filename)
    if labeled.load():
        # cached labels found on disk; reuse them as-is
        return labeled
    labeled.extend(prop)
    loader = category_adaptor.prepare_test_data(flatten(labeled))
    preds = category_adaptor.predict(loader)
    # predictions come back flat; consume them per-image from the left
    for proposals in labeled:
        proposals.pred_classes = np.array([preds.popleft() for _ in range(len(proposals))])
    labeled.flush()
    return labeled
def generate_bounding_box_labels(prop, bbox_adaptor, class_names, cache_filename):
    """Refine the bounding boxes of foreground proposals in ``prop``.

    Background proposals (predicted class outside ``[0, len(class_names))``)
    are dropped first. Results are cached at ``cache_filename``.
    """
    refined = PersistentProposalList(cache_filename)
    if refined.load():
        # cached refined boxes found on disk; reuse them
        return refined
    num_fg = len(class_names)
    # remove (predicted) background proposals
    for proposals in prop:
        foreground = (0 <= proposals.pred_classes) & (proposals.pred_classes < num_fg)
        refined.append(proposals[foreground])
    loader = bbox_adaptor.prepare_test_data(flatten(refined))
    preds = bbox_adaptor.predict(loader)
    # predictions come back flat; consume them per-image from the left
    for proposals in refined:
        proposals.pred_boxes = np.array([preds.popleft() for _ in range(len(proposals))])
    refined.flush()
    return refined
def train(model, logger, cfg, args, args_cls, args_box):
    """Run the full D-adapt pipeline: generate proposals, train the category
    (and optionally bbox) adaptors, then train the detector with the
    adaptor-generated pseudo labels on the target domain.

    Args:
        model: the detection model (possibly wrapped in DistributedDataParallel).
        logger: a logging.Logger for progress messages.
        cfg: detectron2 config node.
        args: command-line namespace (sources/targets, trade_off, bbox_refine, ...).
        args_cls: namespace for the category adaptor (``-c`` options).
        args_box: namespace for the bounding box adaptor (``-b`` options).
    """
    model.train()
    distributed = comm.get_world_size() > 1
    # unwrap DDP to reach get_parameters() on the underlying model
    if distributed:
        model_without_parallel = model.module
    else:
        model_without_parallel = model
    # define optimizer and lr scheduler
    params = []
    # get_parameters yields (module, lr) pairs, allowing per-module learning rates
    for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR):
        params.extend(
            get_default_optimizer_params(
                module,
                base_lr=lr,
                weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
                bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
                weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            )
        )
    optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
        params,
        lr=cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.MOMENTUM,
        nesterov=cfg.SOLVER.NESTEROV,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    scheduler = utils.build_lr_scheduler(cfg, optimizer)
    # resume from the last checkpoint
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
    start_iter = 0
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )
    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []
    # generate proposals from detector
    classes = MetadataCatalog.get(args.targets[0]).thing_classes
    cache_proposal_root = os.path.join(cfg.OUTPUT_DIR, "cache", "proposal")
    prop_t_fg, prop_t_bg = generate_proposals(model, len(classes), args.targets, cache_proposal_root, cfg)
    prop_s_fg, prop_s_bg = generate_proposals(model, len(classes), args.sources, cache_proposal_root, cfg)
    # move the detector off the GPU while the adaptors train
    model = model.to(torch.device('cpu'))
    # train the category adaptor
    category_adaptor = category_adaptation.CategoryAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, "cls"), args_cls)
    if not category_adaptor.load_checkpoint():
        data_loader_source = category_adaptor.prepare_training_data(prop_s_fg + prop_s_bg, True)
        data_loader_target = category_adaptor.prepare_training_data(prop_t_fg + prop_t_bg, False)
        data_loader_validation = category_adaptor.prepare_validation_data(prop_t_fg + prop_t_bg)
        category_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation)
    # generate category labels for each proposals
    cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, "cache", "feedback")
    prop_t_fg = generate_category_labels(
        prop_t_fg, category_adaptor, os.path.join(cache_feedback_root, "{}_fg.json".format(args.targets[0]))
    )
    prop_t_bg = generate_category_labels(
        prop_t_bg, category_adaptor, os.path.join(cache_feedback_root, "{}_bg.json".format(args.targets[0]))
    )
    # free GPU memory before (optionally) training the bbox adaptor
    category_adaptor.model.to(torch.device("cpu"))
    if args.bbox_refine:
        # train the bbox adaptor
        bbox_adaptor = bbox_adaptation.BoundingBoxAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, "bbox"), args_box)
        if not bbox_adaptor.load_checkpoint():
            data_loader_source = bbox_adaptor.prepare_training_data(prop_s_fg, True)
            data_loader_target = bbox_adaptor.prepare_training_data(prop_t_fg, False)
            data_loader_validation = bbox_adaptor.prepare_validation_data(prop_t_fg)
            bbox_adaptor.validate_baseline(data_loader_validation)
            bbox_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation)
        # generate bounding box labels for each proposals
        cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, "cache", "feedback_bbox")
        prop_t_fg_refined = generate_bounding_box_labels(
            prop_t_fg, bbox_adaptor, classes,
            os.path.join(cache_feedback_root, "{}_fg.json".format(args.targets[0]))
        )
        prop_t_bg_refined = generate_bounding_box_labels(
            prop_t_bg, bbox_adaptor, classes,
            os.path.join(cache_feedback_root, "{}_bg.json".format(args.targets[0]))
        )
        # keep both the original and the refined proposals
        prop_t_fg += prop_t_fg_refined
        prop_t_bg += prop_t_bg_refined
        bbox_adaptor.model.to(torch.device("cpu"))
    if args.reduce_proposals:
        # remove proposals: keep only bg proposals confirmed as background,
        # and at most 20 fg proposals per image
        prop_t_bg_new = []
        for p in prop_t_bg:
            keep_indices = p.pred_classes == len(classes)
            prop_t_bg_new.append(p[keep_indices])
        prop_t_bg = prop_t_bg_new
        prop_t_fg_new = []
        for p in prop_t_fg:
            prop_t_fg_new.append(p[:20])
        prop_t_fg = prop_t_fg_new
    # bring the detector back to its configured device for the final stage
    model = model.to(torch.device(cfg.MODEL.DEVICE))
    # Data loading code
    train_source_dataset = get_detection_dataset_dicts(args.sources)
    train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg)
    # target loader feeds the adaptor-labeled proposals as precomputed proposals
    train_target_dataset = get_detection_dataset_dicts(args.targets, proposals_list=prop_t_fg+prop_t_bg)
    mapper = DatasetMapper(cfg, precomputed_proposal_topk=1000, augmentations=utils.build_augmentation(cfg, True))
    train_target_loader = build_detection_train_loader(dataset=train_target_dataset, cfg=cfg, mapper=mapper,
                                                       total_batch_size=cfg.SOLVER.IMS_PER_BATCH)
    # training the object detector
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        for data_s, data_t, iteration in zip(train_source_loader, train_target_loader, range(start_iter, max_iter)):
            storage.iter = iteration
            optimizer.zero_grad()
            # compute losses and gradient on source domain
            loss_dict_s = model(data_s)
            losses_s = sum(loss_dict_s.values())
            assert torch.isfinite(losses_s).all(), loss_dict_s
            loss_dict_reduced_s = {"{}_s".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()}
            losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values())
            losses_s.backward()
            # compute losses and gradient on target domain
            loss_dict_t = model(data_t, labeled=False)
            losses_t = sum(loss_dict_t.values())
            assert torch.isfinite(losses_t).all()
            loss_dict_reduced_t = {"{}_t".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_t).items()}
            # target loss is down-weighted by the trade-off hyper-parameter
            (losses_t * args.trade_off).backward()
            if comm.is_main_process():
                storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s, **loss_dict_reduced_t)
            # do SGD step (gradients from both domains were accumulated above)
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()
            # evaluate on validation set
            if (
                    cfg.TEST.EVAL_PERIOD > 0
                    and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                    and iteration != max_iter - 1
            ):
                utils.validate(model, logger, cfg, args)
                comm.synchronize()
            if iteration - start_iter > 5 and (
                    (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
def main(args, args_cls, args_box):
    """Entry point for one D-adapt phase.

    Builds the datasets from the command line, instantiates the detector named
    in the config, then either evaluates only (``--eval-only``) or runs
    adaptation training followed by a final validation.

    Args:
        args: detection/training arguments.
        args_cls: arguments forwarded to the category adaptor.
        args_box: arguments forwarded to the bounding-box adaptor.

    Returns:
        The result of ``utils.validate`` on the test datasets.
    """
    logger = logging.getLogger("detectron2")
    cfg = utils.setup(args)

    # Expand the alternating ``NAME PATH NAME PATH ...`` CLI pairs into
    # registered dataset names.
    for field in ("sources", "targets", "test"):
        pairs = getattr(args, field)
        setattr(args, field, utils.build_dataset(pairs[::2], pairs[1::2]))

    # Instantiate the meta-architecture declared in the config file.
    detector = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune)
    detector.to(torch.device(cfg.MODEL.DEVICE))
    logger.info("Model:\n{}".format(detector))

    if args.eval_only:
        # Load weights (or resume from the output dir) and only report metrics.
        DetectionCheckpointer(detector, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return utils.validate(detector, logger, cfg, args)

    if comm.get_world_size() > 1:
        detector = DistributedDataParallel(
            detector, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    train(detector, logger, cfg, args, args_cls, args_box)
    # final evaluation on the test set
    return utils.validate(detector, logger, cfg, args)
if __name__ == "__main__":
    # Argument parsing happens in three stages over the same argv: the
    # category adaptor, then the bounding-box adaptor, then the detector.
    # Each stage consumes the options it recognizes and hands the remainder
    # (`argv`) to the next stage via parse_known_args.
    args_cls, argv = category_adaptation.CategoryAdaptor.get_parser().parse_known_args()
    print("Category Adaptation Args:")
    pprint.pprint(args_cls)
    args_box, argv = bbox_adaptation.BoundingBoxAdaptor.get_parser().parse_known_args(args=argv)
    print("Bounding Box Adaptation Args:")
    pprint.pprint(args_box)
    parser = argparse.ArgumentParser(add_help=True)
    # dataset parameters: alternating NAME PATH pairs, e.g.
    # `-s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012`
    parser.add_argument('-s', '--sources', nargs='+', help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', help='target domain(s)')
    parser.add_argument('--test', nargs='+', help='test domain(s)')
    # model parameters
    parser.add_argument('--finetune', action='store_true',
                        help='whether use 10x smaller learning rate for backbone')
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
             "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    # weight of the target-domain loss relative to the source-domain loss
    parser.add_argument('--trade-off', default=1., type=float,
                        help='trade-off hyper-parameter for losses on target domain')
    parser.add_argument('--bbox-refine', action='store_true',
                        help='whether perform bounding box refinement')
    parser.add_argument('--reduce-proposals', action='store_true',
                        help='whether remove some low-quality proposals.'
                             'Helpful for RetinaNet')
    # training parameters
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument("--machine-rank", type=int, default=0,
                        help="the rank of this machine (unique per machine)")
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
             "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    # trailing KEY VALUE pairs override config options (detectron2 convention)
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args, argv = parser.parse_known_args(argv)
    print("Detection Args:")
    pprint.pprint(args)
    # launch spawns one process per GPU and calls main(*args) in each
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args, args_cls, args_box),
    )
================================================
FILE: examples/domain_adaptation/object_detection/d_adapt/d_adapt.sh
================================================
# D-adapt multi-phase adaptation recipes.
# Each command runs one self-training phase: the checkpoint from the previous
# phase (MODEL.WEIGHTS) is adapted on the target domain and saved under
# OUTPUT_DIR; later phases raise the pseudo-label confidence ratio
# (--confidence-ratio-c). The number commented above each command is the
# reported target-domain mAP for that phase.

# ResNet101 Based Faster RCNN: Faster RCNN: VOC->Clipart
# 44.8
pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \
  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \
  --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 47.9
pretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \
  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \
  --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 49.0
pretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase2/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.2 \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \
  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \
  --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0

# ResNet101 Based Faster RCNN: Faster RCNN: VOC->WaterColor
# 54.1
pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \
  -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 57.5
pretrained_models=logs/faster_rcnn_R_101_C4/voc2watercolor/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \
  -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0

# ResNet101 Based Faster RCNN: Faster RCNN: VOC->Comic
# 39.7
pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \
  -t Comic ../datasets/comic --test ComicTest ../datasets/comic --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 41.0
pretrained_models=logs/faster_rcnn_R_101_C4/voc2comic/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \
  -t Comic ../datasets/comic --test ComicTest ../datasets/comic --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0

# ResNet101 Based Faster RCNN: Cityscapes -> Foggy Cityscapes
# 40.1
pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \
  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 42.4
pretrained_models=logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \
  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0

# VGG Based Faster RCNN: Cityscapes -> Foggy Cityscapes
# 33.3
pretrained_models=../logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \
  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 37.0
pretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \
  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 38.9
pretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.2 \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \
  --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0

# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car
# 51.9
pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 8 --ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_R_101_C4/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0

# VGG Based Faster RCNN: Sim10k -> Cityscapes Car
# 49.3
pretrained_models=../logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 8 --ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \
  OUTPUT_DIR logs/faster_rcnn_vgg_16/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0

# RetinaNet: VOC->Clipart
# 44.7
pretrained_models=../logs/source_only/retinanet_R_101_FPN/voc2clipart/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --remove-bg \
  --config-file config/retinanet_R_101_FPN_voc.yaml \
  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \
  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \
  --finetune --bbox-refine \
  OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0
# 46.3
pretrained_models=logs/retinanet_R_101_FPN/voc2clipart/phase1/model_final.pth
CUDA_VISIBLE_DEVICES=0 python d_adapt.py --remove-bg --confidence-ratio 0.1 \
  --config-file config/retinanet_R_101_FPN_voc.yaml \
  -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \
  -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \
  --finetune --bbox-refine \
  OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0
================================================
FILE: examples/domain_adaptation/object_detection/oracle.sh
================================================
# Oracle experiments: train directly on the labeled target domain itself.
# The resulting mAP is an upper-bound reference for each adaptation benchmark
# (source_only.py is reused with source == target).

# Faster RCNN: WaterColor
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s WaterColor datasets/watercolor -t WaterColor datasets/watercolor \
  --test WaterColorTest datasets/watercolor --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/watercolor MODEL.ROI_HEADS.NUM_CLASSES 6
# Faster RCNN: Comic
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s Comic datasets/comic -t Comic datasets/comic \
  --test ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/comic MODEL.ROI_HEADS.NUM_CLASSES 6
# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes2foggy
# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes2foggy
# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1
# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1
================================================
FILE: examples/domain_adaptation/object_detection/prepare_cityscapes_to_voc.py
================================================
from pascal_voc_writer import Writer
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import glob
import time
from shutil import move, copy
import tqdm
# Mapping from Cityscapes label names to the class names written into the
# VOC-style XML (identity mapping; only these 8 categories are converted,
# all other Cityscapes labels are dropped).
classes = {'bicycle': 'bicycle', 'bus': 'bus', 'car': 'car', 'motorcycle': 'motorcycle',
           'person': 'person', 'rider': 'rider', 'train': 'train', 'truck': 'truck'}
classes_keys = list(classes.keys())
def make_dir(path):
    """Create directory *path* (including parents) if it does not exist.

    Uses ``exist_ok=True`` instead of the original check-then-create pattern,
    which is both shorter and free of the TOCTOU race between ``os.path.isdir``
    and ``os.makedirs``.
    """
    os.makedirs(path, exist_ok=True)
#----------------------------------------------------------------------------------------------------------------
#convert polygon to bounding box
#code from:
#https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points
#----------------------------------------------------------------------------------------------------------------
def polygon_to_bbox(polygon):
    """Return the axis-aligned bounding box of *polygon*.

    Args:
        polygon: iterable of (x, y) points.

    Returns:
        [xmin, ymin, xmax, ymax]
    """
    x_coordinates, y_coordinates = zip(*polygon)
    return [min(x_coordinates), min(y_coordinates), max(x_coordinates), max(y_coordinates)]
# --------------------------------------------
# read a json file and convert to voc format
# --------------------------------------------
def read_json(file):
    """Parse one Cityscapes ``*_gtFine_polygons.json`` file into VOC rows.

    Args:
        file: path to the polygon annotation json file.

    Returns:
        data: list of ``[class_name, xmin, ymin, xmax, ymax]`` rows, one per
            object whose label is among the converted VOC classes.
        relevant_file: True iff at least one relevant object was found — the
            caller uses this flag to decide whether to emit an XML file.
    """
    data = []
    with open(file, 'r') as f:
        file_data = json.load(f)
    for obj in file_data['objects']:  # renamed from `object` (builtin shadowing)
        label, polygon = obj['label'], obj['polygon']
        # process only if label found in voc
        if label in classes_keys:
            # np.array(polygon) replaces the redundant np.array([x for x in polygon])
            bbox = polygon_to_bbox(np.array(polygon))
            data.append([classes[label]] + bbox)
    # non-empty data <=> the image contains at least one relevant object
    return data, bool(data)
#---------------------------
#function to save xml file
#---------------------------
def save_xml(img_path, img_shape, data, save_path):
    """Write one pascal-voc-writer XML annotation file to *save_path*.

    Args:
        img_path: image path recorded inside the XML.
        img_shape: (height, width, ...) tuple as returned by ``plt.imread``,
            hence width is ``img_shape[1]`` and height is ``img_shape[0]``.
        data: rows of ``[class_name, xmin, ymin, xmax, ymax]``.
        save_path: destination path of the XML file.
    """
    height, width = img_shape[0], img_shape[1]
    writer = Writer(img_path, width, height)
    for name, xmin, ymin, xmax, ymax in data:
        writer.addObject(name, xmin, ymin, xmax, ymax)
    writer.save(save_path)
def prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir):
    """Convert Cityscapes polygon annotations into a Pascal VOC layout.

    Walks ``<cityscapes_dir>/gtFine/<split>/<city>/*.json`` and writes under
    *save_path*: ``Annotations/*.xml``, ``JPEGImages/*.png`` and
    ``ImageSets/Main/{trainval,test}.txt``.

    Args:
        cityscapes_dir: root of the Cityscapes dataset (contains ``gtFine``).
        save_path: output root for the VOC-style dataset.
        suffix: image filename suffix, e.g. ``"_leftImg8bit"``.
        image_dir: image subdirectory name, e.g. ``"leftImg8bit"``.
    """
    cityscapes_dir_gt = os.path.join(cityscapes_dir, 'gtFine')
    # ------------------------------------------
    # reading json files from each subdirectory
    # ------------------------------------------
    valid_files = []     # [source image path, image file name] pairs to copy
    trainval_files = []  # image ids of the training split
    test_files = []      # image ids of the remaining splits
    # make Annotations target directory if already doesn't exist
    ann_dir = os.path.join(save_path, 'Annotations')
    make_dir(ann_dir)
    start = time.time()
    for category in os.listdir(cityscapes_dir_gt):
        for city in tqdm.tqdm(os.listdir(os.path.join(cityscapes_dir_gt, category))):
            # read files
            files = glob.glob(os.path.join(cityscapes_dir, 'gtFine', category, city) + '/*.json')
            # process json files
            for file in files:
                data, relevant_file = read_json(file)
                if relevant_file:
                    # strip the trailing '_gtFine_polygons.json' (21 chars)
                    base_filename = os.path.basename(file)[:-21]
                    xml_filepath = os.path.join(ann_dir, base_filename + '{}.xml'.format(suffix))
                    img_name = base_filename + '{}.png'.format(suffix)
                    img_path = os.path.join(cityscapes_dir, image_dir, category, city,
                                            base_filename + '{}.png'.format(suffix))
                    img_shape = plt.imread(img_path).shape
                    valid_files.append([img_path, img_name])
                    # record the image id (name without '.png') for the split lists
                    if category == 'train':
                        trainval_files.append(img_name[:-4])
                    else:
                        test_files.append(img_name[:-4])
                    # save xml file
                    save_xml(os.path.join(image_dir, category, city,
                                          base_filename + '{}.png'.format(suffix)), img_shape, data, xml_filepath)
    print('Total Time taken: ', time.time() - start)
    # ----------------------------
    # copy files into target path
    # ----------------------------
    images_savepath = os.path.join(save_path, 'JPEGImages')
    make_dir(images_savepath)
    start = time.time()
    for src_path, img_name in valid_files:
        copy(src_path, os.path.join(images_savepath, img_name))
    # fix: this previously re-printed the stale duration of the annotation
    # stage instead of timing the copy stage
    print('Total Time taken: ', time.time() - start)
    # ---------------------------------------------
    # create text files of trainval and test files
    # ---------------------------------------------
    textfiles_savepath = os.path.join(save_path, 'ImageSets', 'Main')
    make_dir(textfiles_savepath)
    trainval_files_wr = [x + '\n' for x in trainval_files]
    test_files_wr = [x + '\n' for x in test_files]
    with open(os.path.join(textfiles_savepath, 'trainval.txt'), 'w') as f:
        f.writelines(trainval_files_wr)
    with open(os.path.join(textfiles_savepath, 'test.txt'), 'w') as f:
        f.writelines(test_files_wr)
if __name__ == '__main__':
    # Expects the original Cityscapes layout under datasets/cityscapes/
    # (gtFine plus the image directories used below).
    cityscapes_dir = 'datasets/cityscapes/'
    if not os.path.exists(cityscapes_dir):
        print("Please put cityscapes datasets in: {}".format(cityscapes_dir))
        exit(0)
    # clear-weather images -> datasets/cityscapes_in_voc/
    save_path = 'datasets/cityscapes_in_voc/'
    suffix = "_leftImg8bit"
    image_dir = "leftImg8bit"
    prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir)
    # foggy images (only the beta=0.02 fog level) -> datasets/foggy_cityscapes_in_voc/
    save_path = 'datasets/foggy_cityscapes_in_voc/'
    suffix = "_leftImg8bit_foggy_beta_0.02"
    image_dir = "leftImg8bit_foggy"
    prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir)
================================================
FILE: examples/domain_adaptation/object_detection/requirements.txt
================================================
mmcv
timm
prettytable
pascal_voc_writer
================================================
FILE: examples/domain_adaptation/object_detection/source_only.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import logging
import os
import argparse
import sys
import torch
from torch.nn.parallel import DistributedDataParallel
from detectron2.engine import default_writers, launch
from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer
import detectron2.utils.comm as comm
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping
from detectron2.solver import build_lr_scheduler
from detectron2.data import (
build_detection_train_loader,
get_detection_dataset_dicts,
)
from detectron2.utils.events import EventStorage
sys.path.append('../../..')
import tllib.vision.models.object_detection.meta_arch as models
import utils
def train(model, logger, cfg, args):
    """Supervised training loop on the source domain (detectron2 style).

    Args:
        model: detector, possibly wrapped in DistributedDataParallel.
        logger: logger for progress messages.
        cfg: detectron2 config node (solver/checkpoint/test options).
        args: command-line args; ``args.source`` lists the training datasets,
            ``args.resume`` controls checkpoint resumption.
    """
    model.train()
    distributed = comm.get_world_size() > 1
    if distributed:
        # unwrap DDP to reach the custom get_parameters() method
        model_without_parallel = model.module
    else:
        model_without_parallel = model
    # define optimizer and lr scheduler
    # get_parameters() yields (module, lr) pairs so that the backbone can use
    # a smaller learning rate when --finetune is set
    params = []
    for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR):
        params.extend(
            get_default_optimizer_params(
                module,
                base_lr=lr,
                weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
                bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
                weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            )
        )
    optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
        params,
        lr=cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.MOMENTUM,
        nesterov=cfg.SOLVER.NESTEROV,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
    scheduler = build_lr_scheduler(cfg, optimizer)
    # resume from the last checkpoint
    checkpointer = DetectionCheckpointer(
        model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler
    )
    # resume_or_load returns the stored iteration; training restarts one step after it
    start_iter = (
        checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume).get("iteration", -1) + 1
    )
    max_iter = cfg.SOLVER.MAX_ITER
    periodic_checkpointer = PeriodicCheckpointer(
        checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter
    )
    # metric writers (console/json/tensorboard) only on the main process
    writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else []
    # Data loading code
    train_source_dataset = get_detection_dataset_dicts(args.source)
    train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg)
    # start training
    logger.info("Starting training from iteration {}".format(start_iter))
    with EventStorage(start_iter) as storage:
        # the zip with range() bounds the (infinite) train loader to max_iter steps
        for data_s, iteration in zip(train_source_loader, range(start_iter, max_iter)):
            storage.iter = iteration
            # compute output
            # NOTE(review): model returns (outputs, loss_dict); outputs are unused here
            _, loss_dict_s = model(data_s)
            losses_s = sum(loss_dict_s.values())
            assert torch.isfinite(losses_s).all(), loss_dict_s
            # reduce losses over all workers — for logging only
            loss_dict_reduced_s = {"{}_s".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()}
            losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values())
            if comm.is_main_process():
                storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s)
            # compute gradient and do SGD step
            optimizer.zero_grad()
            losses_s.backward()
            optimizer.step()
            storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False)
            scheduler.step()
            # evaluate on validation set periodically (but not at the final iteration,
            # since the caller evaluates once more after train() returns)
            if (
                cfg.TEST.EVAL_PERIOD > 0
                and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0
                and iteration != max_iter - 1
            ):
                utils.validate(model, logger, cfg, args)
                comm.synchronize()
            # flush writers every 20 iterations, after a short warm-up
            if iteration - start_iter > 5 and (
                (iteration + 1) % 20 == 0 or iteration == max_iter - 1
            ):
                for writer in writers:
                    writer.write()
            periodic_checkpointer.step(iteration)
def main(args):
    """Build datasets and the detector, then either evaluate only or train on
    the source domain and report validation results.

    Args:
        args: parsed command-line arguments (see the module's __main__ block).

    Returns:
        The result of ``utils.validate`` on the test datasets.
    """
    logger = logging.getLogger("detectron2")
    cfg = utils.setup(args)

    # Expand the alternating ``NAME PATH NAME PATH ...`` CLI pairs into
    # registered dataset names.
    args.source = utils.build_dataset(args.source[::2], args.source[1::2])
    args.target = utils.build_dataset(args.target[::2], args.target[1::2])
    args.test = utils.build_dataset(args.test[::2], args.test[1::2])

    # Instantiate the meta-architecture declared in the config file.
    detector = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune)
    detector.to(torch.device(cfg.MODEL.DEVICE))
    logger.info("Model:\n{}".format(detector))

    if args.eval_only:
        # Load weights (or resume from the output dir) and only report metrics.
        DetectionCheckpointer(detector, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        return utils.validate(detector, logger, cfg, args)

    if comm.get_world_size() > 1:
        detector = DistributedDataParallel(
            detector, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        )
    train(detector, logger, cfg, args)
    # final evaluation on the test set
    return utils.validate(detector, logger, cfg, args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # dataset parameters: alternating NAME PATH pairs, e.g.
    # `-s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012`
    parser.add_argument('-s', '--source', nargs='+', help='source domain(s)')
    parser.add_argument('-t', '--target', nargs='+', help='target domain(s)')
    parser.add_argument('--test', nargs='+', help='test domain(s)')
    # model parameters
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller learning rate for backbone')
    parser.add_argument(
        "--resume",
        action="store_true",
        help="Whether to attempt to resume from the checkpoint directory. "
             "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
    )
    # training parameters
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument("--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)")
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
             "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    # trailing KEY VALUE pairs override config options (detectron2 convention)
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    print("Command Line Args:", args)
    # launch spawns one process per GPU and calls main(*args) in each
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
================================================
FILE: examples/domain_adaptation/object_detection/source_only.sh
================================================
# Source-only baselines: train each detector on the source domain only and
# evaluate on both the source and the target test sets. The resulting
# checkpoints under logs/source_only/ seed phase 1 of d_adapt.sh.

# Faster RCNN: VOC->Clipart
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart
# Faster RCNN: VOC->WaterColor, Comic
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \
  --test VOC2007PartialTest datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic MODEL.ROI_HEADS.NUM_CLASSES 6
# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy
# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy
# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1
# VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/ \
  --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1
# Faster RCNN: GTA5 -> Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc/ \
  --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \
  OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/gta52cityscapes
# RetinaNet: VOC->Clipart
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/retinanet_R_101_FPN_voc.yaml \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \
  OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2clipart
# RetinaNet: VOC->WaterColor, Comic
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/retinanet_R_101_FPN_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \
  --test VOC2007PartialTest datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2watercolor_comic MODEL.RETINANET.NUM_CLASSES 6
================================================
FILE: examples/domain_adaptation/object_detection/utils.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import prettytable
from typing import *
from collections import OrderedDict, defaultdict
import tempfile
import logging
import matplotlib as mpl
import torch
import torch.nn as nn
import torchvision.transforms as T
import detectron2.utils.comm as comm
from detectron2.evaluation import PascalVOCDetectionEvaluator, inference_on_dataset
from detectron2.evaluation.pascal_voc_evaluation import voc_eval
from detectron2.config import get_cfg, CfgNode
from detectron2.engine import default_setup
from detectron2.data import (
build_detection_test_loader,
)
from detectron2.data.transforms.augmentation import Augmentation
from detectron2.data.transforms import BlendTransform, ColorTransform
from detectron2.solver.lr_scheduler import LRMultiplier, WarmupParamScheduler
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.colormap import random_color
from fvcore.common.param_scheduler import *
import timm
import tllib.vision.datasets.object_detection as datasets
import tllib.vision.models as models
class PascalVOCDetectionPerClassEvaluator(PascalVOCDetectionEvaluator):
    """
    Evaluate Pascal VOC style AP with per-class AP for Pascal VOC dataset.
    It contains a synchronization, therefore has to be called from all ranks.
    Note that the concept of AP can be implemented in different ways and may not
    produce identical results. This class mimics the implementation of the official
    Pascal VOC Matlab API, and should produce similar but not identical results to the
    official API.
    """

    def evaluate(self):
        """
        Returns:
            dict: has a key "bbox", whose value is a dict containing "AP", "AP50",
                "AP75", plus the per-class AP at IoU 0.50 keyed by class name.
                Only the main process returns the dict; other ranks return None.
        """
        # Gather predictions from all ranks; only the main process evaluates.
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        # Merge per-rank prediction lines, grouped by class id.
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        # voc_eval reads detections from text files, so dump each class's
        # predictions into a temporary per-class file first.
        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                # Evaluate at IoU thresholds 0.50, 0.55, ..., 0.95.
                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        # Unlike the base class, additionally report per-class AP at IoU 0.50.
        for cls_name, ap in zip(self._class_names, aps[50]):
            ret["bbox"][cls_name] = ap
        return ret
def validate(model, logger, cfg, args):
    """Run inference on every test dataset and log a per-class AP table.

    Returns an OrderedDict mapping dataset name -> evaluation results; when a
    single test dataset is given, returns that dataset's result dict directly.
    """
    results = OrderedDict()
    for name in args.test:
        loader = build_detection_test_loader(cfg, name)
        evaluator = PascalVOCDetectionPerClassEvaluator(name)
        result = inference_on_dataset(model, loader, evaluator)
        results[name] = result
        # Only the main process receives non-None results; only it logs.
        if not comm.is_main_process():
            continue
        logger.info(result)
        ap_table = prettytable.PrettyTable(["class", "AP"])
        for cls, ap in result["bbox"].items():
            ap_table.add_row([cls, ap])
        logger.info(ap_table.get_string())
    if len(results) == 1:
        return next(iter(results.values()))
    return results
def build_dataset(dataset_categories, dataset_roots):
    """
    Given a sequence of dataset class names and a parallel sequence of root
    directories, instantiate each dataset and return the list of their
    registered names.
    """
    return [
        datasets.__dict__[category](root).name
        for category, root in zip(dataset_categories, dataset_roots)
    ]
def rgb2gray(rgb):
    """Convert an (H, W, 3) RGB array to a 3-channel grayscale array.

    Uses the ITU-R luminance weights (0.2989, 0.5870, 0.1140), replicates the
    single luminance channel three times, and casts back to the input dtype.
    """
    weights = np.array([0.2989, 0.5870, 0.1140])
    luminance = rgb[..., :3] @ weights
    return np.repeat(luminance[:, :, np.newaxis], 3, axis=2).astype(rgb.dtype)
class Grayscale(Augmentation):
    """Augmentation that converts an RGB image to 3-channel grayscale.

    The ``brightness``/``contrast``/``saturation``/``hue`` parameters are kept
    for backward compatibility with existing call sites but are not used: the
    transform always performs a plain luminance conversion via :func:`rgb2gray`.
    (The previous version also constructed an unused ``T.Grayscale()`` instance;
    that dead code has been removed.)
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        # Augmentation._init stores the ctor args as attributes (for repr etc.).
        self._init(locals())

    def get_transform(self, image):
        # Keep 3 identical channels so downstream code expecting RGB still works.
        return ColorTransform(lambda x: rgb2gray(x))
def build_augmentation(cfg, is_train):
    """
    Create a list of :class:`Augmentation` from config: resizing, optional
    random flipping (train only), and randomly-applied color jitter and
    grayscale conversion.

    Returns:
        list[Augmentation]
    """
    import detectron2.data.transforms as T

    if is_train:
        min_size, max_size = cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN
        sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
    else:
        min_size, max_size = cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MAX_SIZE_TEST
        sample_style = "choice"
    augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)]

    if is_train and cfg.INPUT.RANDOM_FLIP != "none":
        flip = T.RandomFlip(
            horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal",
            vertical=cfg.INPUT.RANDOM_FLIP == "vertical",
        )
        augmentation.append(flip)

    # NOTE(review): the color jitter and grayscale augmentations below are
    # appended regardless of ``is_train`` — confirm test-time use is intended.
    color_jitter = T.AugmentationList(
        [
            T.RandomContrast(0.6, 1.4),
            T.RandomBrightness(0.6, 1.4),
            T.RandomSaturation(0.6, 1.4),
            T.RandomLighting(0.1)
        ]
    )
    augmentation.append(T.RandomApply(color_jitter, prob=0.8))
    augmentation.append(T.RandomApply(Grayscale(), prob=0.2))
    return augmentation
def setup(args):
    """
    Create configs and perform basic setups.

    Reads the YAML config named by ``args.config_file``, applies command-line
    KEY VALUE overrides from ``args.opts``, freezes the result, and runs
    detectron2's default setup (logging, output directory, environment dump).
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    # if you don't like any of the default setup, write your own setup code
    default_setup(cfg, args)
    return cfg
def build_lr_scheduler(
    cfg: CfgNode, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:
    """
    Build a LR scheduler from config.

    Supports "WarmupMultiStepLR", "WarmupCosineLR" and "ExponentialLR",
    selected by ``cfg.SOLVER.LR_SCHEDULER_NAME``. Raises ValueError for any
    other name.
    """
    name = cfg.SOLVER.LR_SCHEDULER_NAME

    if name == "WarmupMultiStepLR":
        # Drop milestones beyond MAX_ITER: they would never be reached.
        steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER]
        if len(steps) != len(cfg.SOLVER.STEPS):
            logger = logging.getLogger(__name__)
            logger.warning(
                "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. "
                "These values will be ignored."
            )
        sched = MultiStepParamScheduler(
            values=[cfg.SOLVER.GAMMA ** k for k in range(len(steps) + 1)],
            milestones=steps,
            num_updates=cfg.SOLVER.MAX_ITER,
        )
    elif name == "WarmupCosineLR":
        sched = CosineParamScheduler(1, 0)
    elif name == "ExponentialLR":
        sched = ExponentialParamScheduler(1, cfg.SOLVER.GAMMA)
        # NOTE(review): returns here without wrapping in WarmupParamScheduler,
        # so "ExponentialLR" receives no warmup phase — confirm this is intended.
        return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
    else:
        raise ValueError("Unknown LR scheduler: {}".format(name))

    # Wrap the base schedule with a warmup phase covering the first
    # WARMUP_ITERS iterations (expressed as a fraction of MAX_ITER).
    sched = WarmupParamScheduler(
        sched,
        cfg.SOLVER.WARMUP_FACTOR,
        cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER,
        cfg.SOLVER.WARMUP_METHOD,
    )
    return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER)
def get_model_names():
    """Return the names of all supported backbones: tllib models plus timm models."""
    tllib_names = [
        name
        for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ]
    return sorted(tllib_names) + timm.list_models()
def get_model(model_name, pretrain=True):
    """Create a backbone by name and strip its classification head.

    Args:
        model_name (str): a name from ``tllib.vision.models`` or from timm.
        pretrain (bool): load pretrained weights.

    Returns:
        nn.Module with an ``out_features`` attribute (feature dimension) that
        outputs features instead of class logits.
    """
    if model_name in models.__dict__:
        # load models from common.vision.models
        backbone = models.__dict__[model_name](pretrained=pretrain)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=pretrain)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except AttributeError:
            # Some timm models do not expose get_classifier()/in_features;
            # fall back to replacing ``head`` directly. (Narrowed from a bare
            # ``except:`` which would also swallow KeyboardInterrupt/SystemExit.)
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone
class VisualizerWithoutAreaSorting(Visualizer):
    """
    Visualizer in detectron2 draw instances according to their area's order.
    This visualizer removes sorting code to avoid that boxes with lower confidence
    cover boxes with higher confidence.
    """

    def __init__(self, *args, flip=False, **kwargs):
        """
        Args:
            flip (bool): when True, label text is placed near the top-right
                corner of the box instead of the top-left.
        """
        super(VisualizerWithoutAreaSorting, self).__init__(*args, **kwargs)
        self.flip = flip

    def overlay_instances(
        self,
        *,
        boxes=None,
        labels=None,
        masks=None,
        keypoints=None,
        assigned_colors=None,
        alpha=1
    ):
        """
        Args:
            boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`,
                or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image,
                or a :class:`RotatedBoxes`,
                or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format
                for the N objects in a single image,
            labels (list[str]): the text to be displayed for each instance.
            masks (masks-like object): Supported types are:
                * :class:`detectron2.structures.PolygonMasks`,
                  :class:`detectron2.structures.BitMasks`.
                * list[list[ndarray]]: contains the segmentation masks for all objects in one image.
                  The first level of the list corresponds to individual instances. The second
                  level to all the polygon that compose the instance, and the third level
                  to the polygon coordinates. The third level should have the format of
                  [x0, y0, x1, y1, ..., xn, yn] (n >= 3).
                * list[ndarray]: each ndarray is a binary mask of shape (H, W).
                * list[dict]: each dict is a COCO-style RLE.
            keypoints (Keypoint or array like): an array-like object of shape (N, K, 3),
                where the N is the number of instances and K is the number of keypoints.
                The last dimension corresponds to (x, y, visibility or score).
            assigned_colors (list[matplotlib.colors]): a list of colors, where each color
                corresponds to each mask or box in the image. Refer to 'matplotlib.colors'
                for full list of formats that the colors are accepted in.
            alpha (float): transparency used when drawing mask polygons.
        Returns:
            output (VisImage): image object with visualizations.
        """
        num_instances = None
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output
        if boxes is not None and boxes.shape[1] == 5:
            # 5-column boxes are rotated boxes; delegate to the base-class path.
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # Unlike the base class, instances are drawn in the given order
        # (no sorting by area), so earlier entries may be covered by later ones.
        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], edge_color=color)

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)

            if labels is not None:
                # first get a box
                if boxes is not None:
                    x0, y0, x1, y1 = boxes[i]
                    text_pos = (x0 - 3, y0)  # if drawing boxes, put text on the box corner.
                    horiz_align = "left"
                elif masks is not None:
                    # skip small mask without polygon
                    if len(masks[i].polygons) == 0:
                        continue
                    x0, y0, x1, y1 = masks[i].bbox()

                    # draw text in the center (defined by median) when box is not drawn
                    # median is less sensitive to outliers.
                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
                    horiz_align = "center"
                else:
                    continue  # drawing the box confidence for keypoints isn't very useful.
                # for small objects, draw text at the side to avoid occlusion
                instance_area = (y1 - y0) * (x1 - x0)
                if (
                    instance_area < 1000 * self.output.scale
                    or y1 - y0 < 40 * self.output.scale
                ):
                    if y1 >= self.output.height - 5:
                        text_pos = (x1, y0)
                    else:
                        text_pos = (x0, y1)

                # Scale the font with the instance height relative to image size.
                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
                font_size = (
                    np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                    * 1
                    * self._default_font_size
                )
                if self.flip:
                    # flip mode overrides the computed position (see __init__).
                    text_pos = (x1 - 3, y0 - 30)  # if drawing boxes, put text on the box corner.
                self.draw_text(
                    labels[i],
                    text_pos,
                    color=lighter_color,
                    horizontal_alignment=horiz_align,
                    font_size=font_size,
                )

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)
        return self.output

    def draw_box(self, box_coord, alpha=1, edge_color="g", line_style="-"):
        """
        Draw a single (axis-aligned) box on the output image.

        Args:
            box_coord: (x0, y0, x1, y1) absolute pixel coordinates.
            alpha (float): transparency of the box edge.
            edge_color: matplotlib color of the edge.
            line_style (str): matplotlib line style.

        Returns:
            output (VisImage): image object with the box drawn.
        """
        x0, y0, x1, y1 = box_coord
        width = x1 - x0
        height = y1 - y0

        # NOTE(review): minimum line width of 3 — thicker than detectron2's
        # default, presumably for visibility in saved figures; confirm intent.
        linewidth = max(self._default_font_size / 4, 3)

        self.output.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                width,
                height,
                fill=False,
                edgecolor=edge_color,
                linewidth=linewidth * self.output.scale,
                alpha=alpha,
                linestyle=line_style,
            )
        )
        return self.output
================================================
FILE: examples/domain_adaptation/object_detection/visualize.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
import argparse
import sys
import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
build_detection_test_loader,
MetadataCatalog
)
from detectron2.data import detection_utils
from detectron2.engine import default_setup, launch
from detectron2.utils.visualizer import ColorMode
sys.path.append('../../..')
import tllib.vision.models.object_detection.meta_arch as models
import utils
def visualize(cfg, args, model):
    """Run inference on each test dataset and save annotated prediction images.

    For every image: keeps at most ``args.n_bboxes`` highest-confidence boxes,
    drops boxes scoring at or below ``args.threshold``, and stops after
    ``args.n_visualizations`` images per dataset. Images are written as
    ``<save_path>/<dataset_name>/<index>.png``.
    """
    for dataset_name in args.test:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        # create folder
        dirname = os.path.join(args.save_path, dataset_name)
        os.makedirs(dirname, exist_ok=True)
        metadata = MetadataCatalog.get(dataset_name)
        n_current = 0  # number of images visualized so far for this dataset
        # switch to eval mode
        model.eval()

        with torch.no_grad():
            for batch in data_loader:
                if n_current >= args.n_visualizations:
                    break
                batch_predictions = model(batch)
                for per_image, predictions in zip(batch, batch_predictions):
                    instances = predictions["instances"].to(torch.device("cpu"))
                    # only visualize boxes with highest confidence
                    instances = instances[0: args.n_bboxes]
                    # only visualize boxes with confidence exceeding the threshold
                    instances = instances[instances.scores > args.threshold]
                    # visualize in reverse order of confidence
                    # (most confident boxes are drawn last, i.e. on top)
                    index = [i for i in range(len(instances))]
                    index.reverse()
                    instances = instances[index]

                    img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
                    img = detection_utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)

                    # scale pred_box to original resolution
                    # NOTE(review): a single width-based ratio is used, which
                    # assumes resizing preserved the aspect ratio — confirm.
                    ori_height, ori_width, _ = img.shape
                    height, width = instances.image_size
                    ratio = ori_width / width
                    for i in range(len(instances.pred_boxes)):
                        instances.pred_boxes[i].scale(ratio, ratio)

                    # save original image
                    visualizer = utils.VisualizerWithoutAreaSorting(img, metadata=metadata,
                                                                    instance_mode=ColorMode.IMAGE)
                    output = visualizer.draw_instance_predictions(predictions=instances)
                    filepath = str(n_current) + ".png"
                    filepath = os.path.join(dirname, filepath)
                    output.save(filepath)

                    n_current += 1
                    if n_current >= args.n_visualizations:
                        break
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)  # base config from the YAML file
    cfg.merge_from_list(args.opts)         # command-line KEY VALUE overrides
    cfg.freeze()
    # Run detectron2's default setup (logger, output dir, env collection);
    # if you don't like any of the default setup, write your own setup code.
    default_setup(cfg, args)
    return cfg
def main(args):
    """Build the model named in the config, load its weights, and run visualization."""
    cfg = setup(args)
    model = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=True)
    model.to(torch.device(cfg.MODEL.DEVICE))
    checkpointer = DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR)
    checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=False)
    visualize(cfg, args, model)
if __name__ == "__main__":
    # Command-line entry point: parse detectron2-style launch arguments plus
    # visualization options, then launch main() across the requested GPUs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
             "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument('--test', nargs='+', help='test domain(s)')
    parser.add_argument('--save-path', type=str,
                        help='where to save visualization results ')
    parser.add_argument('--n-visualizations', default=100, type=int,
                        help='maximum number of images to visualize (default: 100)')
    parser.add_argument('--threshold', default=0.5, type=float,
                        help='confidence threshold of bounding boxes to visualize (default: 0.5)')
    parser.add_argument('--n-bboxes', default=10, type=int,
                        help='maximum number of bounding boxes to visualize in a single image (default: 10)')
    args = parser.parse_args()
    print("Command Line Args:", args)
    # --test is given as alternating (dataset_class, root) pairs; register them
    # and replace args.test with the registered dataset names.
    args.test = utils.build_dataset(args.test[::2], args.test[1::2])
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )
================================================
FILE: examples/domain_adaptation/object_detection/visualize.sh
================================================
# Visualization scripts: load trained detector weights (MODEL.WEIGHTS) and
# write annotated prediction images under --save-path.
# Source Only Faster RCNN: VOC->Clipart
CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \
  MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth
# Source Only Faster RCNN: VOC->WaterColor, Comic
CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  --test WaterColorTest datasets/watercolor ComicTest datasets/comic --save-path visualizations/source_only/voc2comic_watercolor \
  MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/README.md
================================================
# Open-set Domain Adaptation for Image Classification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
- [VisDA2017](http://ai.bu.edu/visda-2017/)
## Supported Methods
Supported methods include:
- [Open Set Domain Adaptation (OSBP)](https://arxiv.org/abs/1804.10427)
## Experiment and Results
The shell files provide the scripts to reproduce the benchmark results with the specified hyper-parameters.
For example, if you want to train DANN on Office31, use the following script
```shell script
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
# Assume you have put the datasets under the path `data/office-31`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
```
**Notations**
- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` (labeled ``Source Only`` in some tables below) refers to the model trained only with labeled data from the source domain.
We report ``HOS`` used in [ROS (ECCV 2020)](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610409.pdf) to better measure the abilities of different open set domain adaptation algorithms.
We report the best ``HOS`` in all epochs.
DANN (the baseline model) degrades in performance as training progresses, so its final ``HOS`` will be much lower than the best value reported above.
In contrast, OSBP improves performance stably.
### Office-31 H-Score on ResNet-50
| Methods | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
|-------------|------|-------|-------|-------|-------|-------|-------|
| ERM | 75.9 | 67.7 | 85.7 | 91.4 | 72.1 | 68.4 | 67.8 |
| DANN | 80.4 | 81.4 | 89.1 | 92.0 | 82.5 | 66.7 | 70.4 |
| OSBP | 87.8 | 90.7 | 96.4 | 97.5 | 88.7 | 77.0 | 76.7 |
### Office-Home HOS on ResNet-50
| Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |
|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|
| Source Only | / | 59.8 | 55.2 | 65.2 | 71.4 | 52.8 | 59.6 | 65.2 | 55.8 | 44.8 | 68.0 | 63.8 | 49.4 | 68.0 |
| DANN | / | 64.8 | 55.2 | 65.2 | 71.4 | 52.8 | 59.6 | 65.2 | 55.8 | 44.8 | 68.0 | 63.8 | 49.4 | 68.0 |
| OSBP | 64.7 | 68.6 | 62.0 | 70.8 | 76.5 | 66.4 | 68.8 | 73.8 | 65.8 | 57.1 | 75.4 | 70.6 | 60.6 | 75.9 |
### VisDA-2017 performance on ResNet-50
| Methods | HOS | OS | OS* | UNK | bcycl | bus | car | mcycl | train | truck |
|-------------|------|------|------|------|-------|------|------|-------|-------|-------|
| Source Only | 42.6 | 37.6 | 34.7 | 55.1 | 42.6 | 6.4 | 30.5 | 67.1 | 84.0 | 0.2 |
| DANN | 57.8 | 50.4 | 45.6 | 78.9 | 20.1 | 71.4 | 29.5 | 74.4 | 67.8 | 10.4 |
| OSBP | 75.4 | 67.3 | 62.9 | 94.3 | 63.7 | 75.9 | 49.6 | 74.4 | 86.2 | 27.3 |
## Citation
If you use these methods in your research, please consider citing.
```
@InProceedings{OSBP,
author = {Saito, Kuniaki and Yamamoto, Shohei and Ushiku, Yoshitaka and Harada, Tatsuya},
title = {Open Set Domain Adaptation by Backpropagation},
booktitle = {ECCV},
year = {2018}
}
```
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.classifier import Classifier
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
# Run on GPU when available; all models/tensors in this script are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train/evaluate/analyze a DANN model for open-set domain adaptation.

    ``args.phase`` selects the mode: 'train' (default), 'test' (evaluate the
    best checkpoint on the test set) or 'analysis' (t-SNE plot + A-distance of
    the learned features).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seed python/torch RNGs and make cuDNN deterministic for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Endless iterators so each epoch runs a fixed number of iterations.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(
        device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        # validate() returns the H-score, despite the variable name.
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        h_score = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)

    print("best_h_score = {:3.1f}".format(best_h_score))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch with source classification loss plus the
    domain-adversarial alignment loss (weighted by ``args.trade_off``)."""
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        # Forward source and target in a single batch, then split the results.
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # Supervised loss on source labels + adversarial feature-alignment loss.
        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        # Target accuracy is logged for monitoring only; target labels are
        # never used in the loss.
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate on ``val_loader`` and return the H-score.

    The last class index is treated as "unknown": its softmax score is
    overwritten with ``args.threshold``, so argmax predicts "unknown" whenever
    every known class scores below the threshold. The H-score is the harmonic
    mean of the mean known-class accuracy and the unknown-class accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            softmax_output = F.softmax(output, dim=1)
            # Thresholding trick: set the "unknown" score to the threshold so
            # argmax only picks a known class if it beats args.threshold.
            softmax_output[:, -1] = args.threshold

            # measure accuracy and record loss
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        acc_global, accs, iu = confmat.compute()
        all_acc = torch.mean(accs).item() * 100
        known = torch.mean(accs[:-1]).item() * 100
        unknown = accs[-1].item() * 100
        # Harmonic mean of known and unknown accuracy (HOS metric).
        h_score = 2 * known * unknown / (known + unknown)
        if args.per_class_eval:
            print(confmat.format(classes))

    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))

    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DANN for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    # Fixed help text: the actual default is 0.8, not 0.5 as previously stated.
    parser.add_argument('--threshold', default=0.8, type=float,
                        help='When class confidence is less than the given threshold, '
                             'model will output "unknown" (default: 0.8)')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    # Fixed help text: the actual default is 2, not 4 as previously stated.
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/dann.sh
================================================
#!/usr/bin/env bash
# Benchmark scripts for the DANN baseline on open-set domain adaptation.
# Each command trains one source->target task; see README.md for result tables.
# Office31
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --threshold 0.9 --log logs/dann/Office31_W2A
# Office-Home
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Pr
# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 30 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/dann/VisDA2017_S2R
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.classifier import Classifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point: train/evaluate a source-only (ERM) baseline for open-set domain adaptation.

    Depending on ``args.phase`` this either trains on the labeled source domain
    (keeping the checkpoint with the best validation H-score), runs a test-only
    evaluation, or performs feature analysis (t-SNE plot and A-distance).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # ERM trains on the source domain only; the unlabeled target training set
    # returned by get_dataset is deliberately discarded.
    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint when we are not training from scratch
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # using shuffled val loader
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(val_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, classifier, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        h_score = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)

    print("best_h_score = {:3.1f}".format(best_h_score))

    # evaluate the best checkpoint on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of supervised training on labeled source batches.

    Performs ``args.iters_per_epoch`` SGD steps of plain cross-entropy
    minimization, stepping the LR scheduler after every batch and printing
    progress every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    step_start = time.time()
    for step in range(args.iters_per_epoch):
        images_s, target_s = next(train_source_iter)
        images_s = images_s.to(device)
        target_s = target_s.to(device)

        # time spent fetching + transferring the batch
        data_time.update(time.time() - step_start)

        # forward pass: plain ERM objective on the source domain
        logits_s, feats_s = model(images_s)
        loss = F.cross_entropy(logits_s, target_s)

        batch_size = images_s.size(0)
        src_acc = accuracy(logits_s, target_s)[0]
        losses.update(loss.item(), batch_size)
        cls_accs.update(src_acc.item(), batch_size)

        # backward pass and parameter/LR update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # time spent on the full iteration
        batch_time.update(time.time() - step_start)
        step_start = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the open-set H-score.

    The last class index is treated as "unknown". A source-only model never
    predicts "unknown" on its own, so the unknown-class probability of every
    sample is pinned to ``args.threshold``: argmax then selects "unknown"
    exactly when all known-class confidences fall below the threshold.

    Returns:
        H-score (percent): the harmonic mean of the mean known-class accuracy
        and the unknown-class accuracy. 0 when both accuracies are 0.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            softmax_output = F.softmax(output, dim=1)
            # reject-by-threshold: the "unknown" (last) class wins unless some
            # known class is more confident than args.threshold
            softmax_output[:, -1] = args.threshold

            # measure accuracy and record loss
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    acc_global, accs, iu = confmat.compute()
    all_acc = torch.mean(accs).item() * 100
    known = torch.mean(accs[:-1]).item() * 100
    unknown = accs[-1].item() * 100
    # guard against ZeroDivisionError when both known and unknown accuracy are 0
    denominator = known + unknown
    h_score = 2 * known * unknown / denominator if denominator > 0 else 0.
    if args.per_class_eval:
        print(confmat.format(classes))
    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))
    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Source Only for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    # help text fixed to match the actual default value (was "(default: 0.5)")
    parser.add_argument('--threshold', default=0.8, type=float,
                        help='When class confidence is less than the given threshold, '
                             'model will output "unknown" (default: 0.8)')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # help text fixed to match the actual default value (was "(default: 4)")
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/erm.sh
================================================
#!/usr/bin/env bash
# Launch scripts for the source-only (ERM) open-set domain adaptation baseline.
# Each command trains one source->target transfer task; logs and checkpoints
# are written under the directory given by --log.

# Office31
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/erm/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Pr

# VisDA-2017
# Synthetic-to-real: center-crop resizing and per-class accuracy reporting.
CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 30 -i 500 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/erm/VisDA2017_S2R
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/osbp.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.osbp import ImageClassifier as Classifier, UnknownClassBinaryCrossEntropy
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point: train/evaluate OSBP (Open Set Back-Propagation) for open-set domain adaptation.

    Depending on ``args.phase`` this either trains adversarially on labeled
    source and unlabeled target data (keeping the checkpoint with the best
    validation H-score), runs a test-only evaluation, or performs feature
    analysis (t-SNE plot and A-distance).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(device)
    print(classifier)
    # OSBP loss: pushes the target "unknown" probability towards t=0.5 so that
    # gradient reversal separates known from unknown target samples.
    unknown_bce = UnknownClassBinaryCrossEntropy(t=0.5)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint when we are not training from scratch
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, unknown_bce, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        h_score = validate(val_loader, classifier, args)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)

    print("best_h_score = {:3.1f}".format(best_h_score))

    # evaluate the best checkpoint on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: Classifier,
          unknown_bce: UnknownClassBinaryCrossEntropy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one OSBP training epoch.

    Each of the ``args.iters_per_epoch`` steps minimizes cross-entropy on a
    labeled source batch plus the unknown-class binary cross-entropy on a
    target batch forwarded with gradient reversal, then steps the LR
    scheduler. Progress is printed every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    step_start = time.time()
    for step in range(args.iters_per_epoch):
        images_s, target_s = next(train_source_iter)
        images_t, target_t = next(train_target_iter)
        images_s, target_s = images_s.to(device), target_s.to(device)
        images_t, target_t = images_t.to(device), target_t.to(device)

        # time spent fetching + transferring both batches
        data_time.update(time.time() - step_start)

        # source batch: plain supervised pass; target batch: adversarial pass
        # through the gradient reversal layer
        logits_s, _ = model(images_s, grad_reverse=False)
        logits_t, _ = model(images_t, grad_reverse=True)
        src_loss = F.cross_entropy(logits_s, target_s)
        trans_loss = unknown_bce(logits_t)
        loss = src_loss + trans_loss

        src_acc = accuracy(logits_s, target_s)[0]
        tgt_acc = accuracy(logits_t, target_t)[0]
        losses.update(loss.item(), images_s.size(0))
        trans_losses.update(trans_loss.item(), images_s.size(0))
        cls_accs.update(src_acc.item(), images_s.size(0))
        tgt_accs.update(tgt_acc.item(), images_t.size(0))

        # backward pass and parameter/LR update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # time spent on the full iteration
        batch_time.update(time.time() - step_start)
        step_start = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the open-set H-score.

    The model outputs logits over the known classes plus a final "unknown"
    class that OSBP trains directly, so no confidence thresholding is needed
    here — the raw argmax is used.

    Returns:
        H-score (percent): the harmonic mean of the mean known-class accuracy
        and the unknown-class accuracy. 0 when both accuracies are 0.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)

            # measure accuracy and record loss
            confmat.update(target, output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    acc_global, accs, iu = confmat.compute()
    all_acc = torch.mean(accs).item() * 100
    known = torch.mean(accs[:-1]).item() * 100
    unknown = accs[-1].item() * 100
    # guard against ZeroDivisionError when both known and unknown accuracy are 0
    denominator = known + unknown
    h_score = 2 * known * unknown / denominator if denominator > 0 else 0.
    if args.per_class_eval:
        print(confmat.format(classes))
    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))
    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='OSBP for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # help text fixed to match the actual default value (was "(default: 4)")
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='osbp',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/osbp.sh
================================================
#!/usr/bin/env bash
# Launch scripts for OSBP (Open Set Back-Propagation) open-set domain adaptation.
# Each command trains one source->target transfer task; logs and checkpoints
# are written under the directory given by --log.

# Office31
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/osbp/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Pr

# VisDA-2017
# Synthetic-to-real: center-crop resizing and per-class accuracy reporting.
CUDA_VISIBLE_DEVICES=0 python osbp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 30 -i 1000 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/osbp/VisDA2017_S2R
================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/utils.py
================================================
import sys
import timm
import torch.nn as nn
import torchvision.transforms as T
sys.path.append('../../..')
import tllib.vision.datasets.openset as datasets
from tllib.vision.datasets.openset import default_open_set as open_set
import tllib.vision.models as models
from tllib.vision.transforms import ResizeImage
def get_model_names():
    """Return all supported backbone names: tllib.vision.models (sorted) followed by timm models."""
    tllib_names = [
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ]
    return sorted(tllib_names) + timm.list_models()
def get_model(model_name):
    """Construct an ImageNet-pretrained backbone by name.

    Looks the name up in ``tllib.vision.models`` first, then falls back to
    pytorch-image-models (timm). For timm models, the classification head is
    removed and two attributes are attached:

    - ``out_features``: dimensionality of the extracted feature vector,
    - ``copy_head``: a handle for retrieving the original classification head.
      NOTE(review): in the first branch this is the zero-argument bound method
      ``get_classifier``, in the fallback it is a one-argument lambda — the two
      call signatures differ; confirm how callers invoke ``copy_head``.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models (timm)
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
            backbone.copy_head = backbone.get_classifier
        except Exception:
            # Some timm models expose the head as an attribute instead of via
            # get_classifier(); strip `.head` directly in that case.
            # (was a bare `except:`, which would also swallow KeyboardInterrupt
            # and SystemExit)
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
            backbone.copy_head = lambda x: x.head
    return backbone
def get_dataset_names():
    """Return the sorted names of all dataset classes exported by tllib.vision.datasets.openset."""
    public = (name for name in datasets.__dict__ if not name.startswith("__"))
    return sorted(name for name in public if callable(datasets.__dict__[name]))
def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):
    """Build open-set source/target datasets for ``dataset_name``.

    Returns a 6-tuple: (train_source_dataset, train_target_dataset,
    val_dataset, test_dataset, num_classes, class_names). When
    ``train_target_transform`` is omitted, the source training transform is
    reused for the target domain.
    """
    if train_target_transform is None:
        train_target_transform = train_source_transform

    # Resolve the dataset class and wrap it with the default open-set label split:
    # the source keeps only known classes, the target adds the "unknown" class.
    base = datasets.__dict__[dataset_name]
    source_cls = open_set(base, source=True)
    target_cls = open_set(base, source=False)

    train_source_dataset = source_cls(root=root, task=source, download=True, transform=train_source_transform)
    train_target_dataset = target_cls(root=root, task=target, download=True, transform=train_target_transform)
    val_dataset = target_cls(root=root, task=target, download=True, transform=val_transform)
    # DomainNet ships a dedicated test split; every other dataset reuses validation data.
    if dataset_name == 'DomainNet':
        test_dataset = target_cls(root=root, task=target, split='test', download=True, transform=val_transform)
    else:
        test_dataset = val_dataset

    class_names = train_source_dataset.classes
    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, len(class_names), class_names
def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):
    """Build the training-time image transform, ending with ToTensor + ImageNet normalization.

    resizing mode:
        - default: resize the image to 256 and take a random resized crop of size 224;
        - cen.crop: resize the image to 256 and take the center crop of size 224;
        - ran.crop: resize the image to 256 and take a random crop of size 224;
        - res.: resize the image to 224;
        - res.|crop: resize the image to (256, 256) and take a random crop of size 224;
        - res.sma|crop: resize the image keeping its aspect ratio such that the
          smaller side is 256, then take a random crop of size 224;
        - inc.crop: "inception crop" from (Szegedy et al., 2015);
        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then
          take a random crop of size 224.

    Raises:
        NotImplementedError: if ``resizing`` is not one of the modes above.
    """
    if resizing == 'default':
        transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224)
        ])
    elif resizing == 'cen.crop':
        transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224)
        ])
    elif resizing == 'ran.crop':
        transform = T.Compose([
            ResizeImage(256),
            T.RandomCrop(224)
        ])
    elif resizing == 'res.':
        transform = T.Resize(224)
    elif resizing == 'res.|crop':
        transform = T.Compose([
            T.Resize((256, 256)),
            T.RandomCrop(224)
        ])
    elif resizing == "res.sma|crop":
        transform = T.Compose([
            T.Resize(256),
            T.RandomCrop(224)
        ])
    elif resizing == 'inc.crop':
        transform = T.RandomResizedCrop(224)
    elif resizing == 'cif.crop':
        transform = T.Compose([
            T.Resize((224, 224)),
            T.Pad(28),
            T.RandomCrop(224),
        ])
    else:
        raise NotImplementedError(resizing)
    transforms = [transform]
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    # All modes end with tensor conversion and ImageNet mean/std normalization.
    transforms.extend([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return T.Compose(transforms)
def get_val_transform(resizing='default'):
    """Build the evaluation-time transform pipeline.

    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to (224, 224);
        - res.|crop: resize the image such that the smaller side is of size 256
          and then take a central crop of size 224.

    The chosen resizing step is followed by ToTensor and ImageNet normalization.
    """
    # map each supported mode to a factory producing the resizing front-end
    resize_factories = {
        'default': lambda: T.Compose([ResizeImage(256), T.CenterCrop(224)]),
        'res.': lambda: T.Resize((224, 224)),
        'res.|crop': lambda: T.Compose([T.Resize(256), T.CenterCrop(224)]),
    }
    if resizing not in resize_factories:
        raise NotImplementedError(resizing)
    return T.Compose([
        resize_factories[resizing](),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/README.md
================================================
# Partial Domain Adaptation for Image Classification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
- [VisDA2017](http://ai.bu.edu/visda-2017/)
## Supported Methods
Supported methods include:
- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)
- [Partial Adversarial Domain Adaptation (PADA)](https://arxiv.org/abs/1808.04205)
- [Importance Weighted Adversarial Nets (IWAN)](https://arxiv.org/abs/1803.09210)
- [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf)
## Experiment and Results
The shell files give the script to reproduce the benchmark with specified hyper-parameters.
For example, if you want to train DANN on Office31, use the following script
```shell script
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
# Assume you have put the datasets under the path `data/office31`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
```
**Notations**
- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.
We found that the accuracies of adversarial methods (including DANN) are not stable
even after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017*
for three times and report their average accuracy.
### Office-31 accuracy on ResNet-50
| Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
|-------------|--------|------|-------|-------|-------|-------|-------|-------|
| ERM | 75.6 | 90.1 | 78.3 | 98.3 | 99.4 | 87.3 | 88.5 | 88.8 |
| DANN | 43.4 | 82.4 | 60.0 | 94.9 | 98.1 | 71.3 | 84.9 | 85.0 |
| PADA | 92.7 | 93.8 | 86.4 | 100.0 | 100.0 | 87.3 | 93.8 | 95.4 |
| IWAN | 94.7 | 94.8 | 91.2 | 99.7 | 99.4 | 89.8 | 94.2 | 94.3 |
| AFN | / | 93.1 | 87.8 | 95.6 | 99.4 | 87.9 | 93.9 | 94.1 |
### Office-Home accuracy on ResNet-50
| Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |
|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|
| ERM | 53.7 | 60.1 | 42.0 | 66.9 | 78.5 | 56.4 | 55.2 | 65.4 | 57.9 | 36.0 | 75.5 | 68.7 | 43.6 | 74.8 |
| DANN | 47.4 | 57.0 | 46.2 | 59.3 | 76.9 | 47.0 | 47.4 | 56.4 | 51.6 | 38.8 | 72.1 | 68.0 | 46.1 | 74.2 |
| PADA | 62.1 | 65.9 | 52.9 | 69.3 | 82.8 | 59.0 | 57.5 | 66.4 | 66.0 | 41.7 | 82.5 | 78.0 | 50.2 | 84.1 |
| IWAN | 63.6 | 71.3 | 59.2 | 76.6 | 84.0 | 67.8 | 66.7 | 69.2 | 73.3 | 55.0 | 83.9 | 79.0 | 58.3 | 82.2 |
| AFN | 71.8 | 72.6 | 59.2 | 76.7 | 82.8 | 72.5 | 74.5 | 76.8 | 72.5 | 56.7 | 80.8 | 77.0 | 60.5 | 81.6 |
### VisDA-2017 accuracy on ResNet-50
| Methods | Origin | Mean | plane | bcycl | bus | car | horse | knife | Avg |
|-------------|--------|------|-------|-------|------|------|-------|-------|------|
| ERM | 45.3 | 50.9 | 59.2 | 31.3 | 68.7 | 73.2 | 69.3 | 3.4 | 60.0 |
| DANN | 51.0 | 55.9 | 88.4 | 34.1 | 72.1 | 50.7 | 61.9 | 27.8 | 57.1 |
| PADA | 53.5 | 60.5 | 89.4 | 35.1 | 72.5 | 69.2 | 86.7 | 10.1 | 66.8 |
| IWAN | / | 61.5 | 89.2 | 57.0 | 61.5 | 55.2 | 80.1 | 25.7 | 66.8 |
| AFN | 67.6 | 61.0 | 79.1 | 62.7 | 73.9 | 49.6 | 79.6 | 21.0 | 64.1 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{DANN,
author = {Ganin, Yaroslav and Lempitsky, Victor},
    booktitle = {ICML},
    title = {Unsupervised domain adaptation by backpropagation},
    year = {2015}
}
@InProceedings{PADA,
author = {Zhangjie Cao and
Lijia Ma and
Mingsheng Long and
Jianmin Wang},
title = {Partial Adversarial Domain Adaptation},
booktitle = {ECCV},
year = {2018}
}
@InProceedings{IWAN,
author = {Jing Zhang and
Zewei Ding and
Wanqing Li and
Philip Ogunbona},
title = {Importance Weighted Adversarial Nets for Partial Domain Adaptation},
booktitle = {CVPR},
year = {2018}
}
@InProceedings{AFN,
author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang},
title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation},
booktitle = {ICCV},
year = {2019}
}
```
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/afn.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier
from tllib.modules.entropy import entropy
import tllib.vision.models as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train / evaluate / analyse an AFN model for partial domain adaptation.

    ``args.phase`` selects the mode: 'train' runs the full training loop,
    'test' evaluates the best checkpoint on the test split, and 'analysis'
    extracts features for t-SNE visualization and A-distance computation.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    # Use utils.get_model so that both torchvision and timm architectures work;
    # the previous extra `models.__dict__[args.arch](pretrained=True)` call
    # silently replaced this backbone and broke timm architectures.
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    # use the num_classes returned by get_dataset, consistent with the other examples
    classifier = ImageClassifier(backbone, num_classes, args.num_blocks,
                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p,
                                 pool_layer=pool_layer).to(device)
    adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper
    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training: keep the checkpoint with the best validation accuracy
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Run one AFN training epoch.

    Optimizes the source classification loss plus the adaptive feature-norm
    loss on both domains, with an optional entropy-minimization term on the
    target predictions.
    """
    time_meter = AverageMeter('Time', ':3.1f')
    data_meter = AverageMeter('Data', ':3.1f')
    cls_loss_meter = AverageMeter('Cls Loss', ':3.2f')
    norm_loss_meter = AverageMeter('Norm Loss', ':3.2f')
    src_norm_meter = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_norm_meter = AverageMeter('Target Feature Norm', ':3.2f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [time_meter, data_meter, cls_loss_meter, norm_loss_meter, src_norm_meter, tgt_norm_meter,
         src_acc_meter, tgt_acc_meter],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device)
        x_t, labels_t = x_t.to(device), labels_t.to(device)

        # measure data loading time
        data_meter.update(time.time() - tic)

        # forward both domains through the classifier
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # source classification loss + feature-norm loss on both domains
        cls_loss = F.cross_entropy(y_s, labels_s)
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        total_loss = cls_loss + norm_loss * args.trade_off_norm
        if args.trade_off_entropy:
            # entropy minimization on target predictions
            y_t = F.softmax(y_t, dim=1)
            total_loss = total_loss + entropy(y_t, reduction='mean') * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # update statistics
        batch = x_s.size(0)
        cls_loss_meter.update(cls_loss.item(), batch)
        norm_loss_meter.update(norm_loss.item(), batch)
        src_norm_meter.update(f_s.norm(p=2, dim=1).mean().item(), batch)
        tgt_norm_meter.update(f_t.norm(p=2, dim=1).mean().item(), batch)
        src_acc_meter.update(accuracy(y_s, labels_s)[0].item(), batch)
        tgt_acc_meter.update(accuracy(y_t, labels_t)[0].item(), batch)

        # measure elapsed time
        time_meter.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for the AFN partial domain adaptation example.
    parser = argparse.ArgumentParser(description='AFN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('-n', '--num-blocks', default=1, type=int, help='Number of basic blocks for classifier')
    parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck')
    parser.add_argument('--dropout-p', default=0.5, type=float,
                        help='Dropout probability')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default: 5e-4)',
                        dest='weight_decay')
    # AFN-specific hyper-parameters (norm/entropy loss weights and norm increment)
    parser.add_argument('--trade-off-norm', default=0.05, type=float,
                        help='the trade-off hyper-parameter for norm loss')
    parser.add_argument('--trade-off-entropy', default=None, type=float,
                        help='the trade-off hyper-parameter for entropy loss')
    parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='afn',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/afn.sh
================================================
#!/usr/bin/env bash
# Reproduce the AFN partial-domain-adaptation benchmark.
# Each command trains one source -> target task; logs and checkpoints are written under --log.

# Office31
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Pr

# VisDA-2017 (per-class evaluation; smaller norm increment -r 0.3)
CUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 -r 0.3 -b 36 \
  --epochs 30 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.classifier import Classifier
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Drive the DANN partial-domain-adaptation experiment.

    Depending on ``args.phase`` this either trains the model ('train'),
    evaluates the best checkpoint on the test split ('test'), or visualizes
    features and computes the A-distance between domains ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # build the data pipeline: transforms, datasets, loaders, endless iterators
    transform_train = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    transform_val = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", transform_train)
    print("val_transform: ", transform_val)

    source_train_set, target_train_set, val_set, test_set, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, transform_train, transform_val)
    source_loader = DataLoader(source_train_set, batch_size=args.batch_size,
                               shuffle=True, num_workers=args.workers, drop_last=True)
    target_loader = DataLoader(target_train_set, batch_size=args.batch_size,
                               shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    source_iter = ForeverDataIterator(source_loader)
    target_iter = ForeverDataIterator(target_loader)

    # build the classifier and the adversarial domain discriminator
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        # ImageNet-Caltech reuses the backbone's original classification head
        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(), pool_layer=pool_layer).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # optimizer with the inverse-decay learning-rate schedule used by DANN
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # adversarial alignment loss
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    if args.phase != 'train':
        # resume from the best checkpoint for test / analysis
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(source_loader, feature_extractor, device)
        target_feature = collect_feature(target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # A-distance measures the discrepancy between the two feature distributions
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # training loop: keep the checkpoint with the best validation accuracy
    best_acc1 = 0.
    for epoch in range(args.epochs):
        train(source_iter, target_iter, classifier, domain_adv, optimizer,
              lr_scheduler, epoch, args)

        acc1 = utils.validate(val_loader, classifier, args, device)

        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # final evaluation of the best checkpoint on the test split
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv: DomainAdversarialLoss, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one DANN training epoch.

    Minimizes the source classification loss plus the weighted domain
    adversarial loss; both domains are processed in a single concatenated
    forward pass.
    """
    time_meter = AverageMeter('Time', ':5.2f')
    data_meter = AverageMeter('Data', ':5.2f')
    loss_meter = AverageMeter('Loss', ':6.2f')
    src_acc_meter = AverageMeter('Cls Acc', ':3.1f')
    tgt_acc_meter = AverageMeter('Tgt Acc', ':3.1f')
    dom_acc_meter = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [time_meter, data_meter, loss_meter, src_acc_meter, tgt_acc_meter, dom_acc_meter],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device)
        x_t, labels_t = x_t.to(device), labels_t.to(device)

        # measure data loading time
        data_meter.update(time.time() - tic)

        # single forward pass over the concatenated source/target batch
        outputs, features = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = outputs.chunk(2, dim=0)
        f_s, f_t = features.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        total_loss = cls_loss + transfer_loss * args.trade_off

        # update statistics (before the optimizer step, as in the original)
        batch = x_s.size(0)
        loss_meter.update(total_loss.item(), batch)
        src_acc_meter.update(accuracy(y_s, labels_s)[0].item(), batch)
        tgt_acc_meter.update(accuracy(y_t, labels_t)[0].item(), batch)
        dom_acc_meter.update(domain_adv.domain_discriminator_accuracy.item(), batch)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        time_meter.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for the DANN partial domain adaptation example.
    parser = argparse.ArgumentParser(description='DANN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    # help text corrected: the actual default below is 2, not 4
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/dann.sh
================================================
#!/usr/bin/env bash
# Reproduce the DANN partial-domain-adaptation benchmark.
# Each command trains one source -> target task; logs and checkpoints are written under --log.

# Office31
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Pr

# VisDA-2017 (per-class evaluation)
CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 5 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/dann/VisDA2017_S2R

# ImageNet-Caltech (C2I uses a wider bottleneck to match the 1000-way head)
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --epochs 5 --seed 0 --log logs/dann/I2C
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --epochs 5 --seed 0 --bottleneck-dim 2048 --log logs/dann/C2I
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.classifier import Classifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run the source-only (ERM) baseline for partial domain adaptation.

    Behavior depends on ``args.phase``:
      * ``'train'``    -- train on the source domain, select by target validation
                          accuracy, then report test accuracy of the best model,
      * ``'test'``     -- evaluate the saved best checkpoint on the test set,
      * ``'analysis'`` -- visualize features with t-SNE and compute the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Fix random seeds for reproducibility; deterministic CUDNN can be slower.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # ERM never looks at target training data, so the target train split is discarded.
    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Wrap the loader so training can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    # For ImageNet->Caltech the backbone's original (1000-way) head is reused.
    head = backbone.copy_head() if args.data == 'ImageNetCaltech' else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, head=head).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this factor,
    # which itself contains args.lr; this matches the convention used throughout
    # these examples -- confirm if reusing elsewhere.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # using shuffled val loader
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(val_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, classifier, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch on source-domain batches only (plain ERM).

    Runs ``args.iters_per_epoch`` SGD steps with cross-entropy loss, stepping
    the lr scheduler after every iteration and printing progress meters every
    ``args.print_freq`` steps.
    """
    time_meter = AverageMeter('Time', ':4.2f')
    fetch_meter = AverageMeter('Data', ':3.1f')
    loss_meter = AverageMeter('Loss', ':3.2f')
    acc_meter = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [time_meter, fetch_meter, loss_meter, acc_meter],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode (dropout / batch-norm statistics)
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        # fetch the next source batch and move it onto the compute device
        x_s, labels_s = next(train_source_iter)
        x_s, labels_s = x_s.to(device), labels_s.to(device)

        # measure data loading time
        fetch_meter.update(time.time() - tic)

        # forward pass: logits and features (features are unused by plain ERM)
        y_s, f_s = model(x_s)
        loss = F.cross_entropy(y_s, labels_s)

        # bookkeeping before the parameter update
        batch = x_s.size(0)
        loss_meter.update(loss.item(), batch)
        acc_meter.update(accuracy(y_s, labels_s)[0].item(), batch)

        # backward pass and SGD step, then advance the lr schedule
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time for the full iteration
        time_meter.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Source Only for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # fix: help text previously said "(default: 4)" but the default is 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/erm.sh
================================================
#!/usr/bin/env bash
# Source-only (ERM) baselines for partial domain adaptation.
# Each command trains on the full source domain (-s) and evaluates on the
# partial target domain (-t); logs and checkpoints are written under --log.

# Office31: 6 transfer tasks between Amazon (A), DSLR (D) and Webcam (W)
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2A
# Office-Home: 12 tasks between Art (Ar), Clipart (Cl), Product (Pr), Real-World (Rw);
# -i 500 fixes 500 iterations per epoch
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr
# VisDA-2017: synthetic-to-real; report per-class accuracy
CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 10 -i 500 --seed 0 --per-class-eval --log logs/erm/VisDA2017_S2R
# ImageNet-Caltech: large-scale partial DA in both directions (I->C and C->I)
CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --epochs 20 --seed 0 -i 2000 --log logs/erm/I2C
CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --epochs 20 --seed 0 -i 2000 --log logs/erm/C2I
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/iwan.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.classifier import Classifier
from tllib.modules.entropy import entropy
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.reweight.iwan import ImportanceWeightModule, ImageClassifier
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run IWAN (Importance Weighted Adversarial Nets) for partial domain adaptation.

    Behavior depends on ``args.phase``:
      * ``'train'``    -- adversarial training with importance-weighted source samples,
      * ``'test'``     -- evaluate the saved best checkpoint on the test set,
      * ``'analysis'`` -- visualize features with t-SNE and compute the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Fix random seeds for reproducibility; deterministic CUDNN can be slower.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Wrap the loaders so training can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        # Reuse the backbone's original classification head for ImageNet->Caltech.
        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(), pool_layer=pool_layer).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)

    # define domain classifier D, D_0
    # D estimates sample importance (trained on detached features);
    # D_0 is the adversarial discriminator used to align the weighted distributions.
    D = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024, batch_norm=False).to(device)
    D_0 = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024, batch_norm=False).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + D.get_parameters() + D_0.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv_D = DomainAdversarialLoss(D).to(device)
    domain_adv_D_0 = DomainAdversarialLoss(D_0).to(device)

    # define importance weight module
    importance_weight_module = ImportanceWeightModule(D, train_target_dataset.partial_classes_idx)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv_D, domain_adv_D_0,
              importance_weight_module, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv_D: DomainAdversarialLoss, domain_adv_D_0: DomainAdversarialLoss,
          importance_weight_module, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train one epoch of IWAN.

    Per iteration: a classification loss on source labels, an adversarial loss
    for the importance-weight discriminator D (computed on detached features so
    only D is updated through it), an importance-weighted adversarial loss for
    the aligning discriminator D_0, and an entropy loss on target predictions.
    Also logs the average importance weight on partial vs non-partial source
    classes for debugging.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs,
         domain_accs_D, domain_accs_D_0, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: forward source and target through the model in one batch
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss (source labels only)
        cls_loss = F.cross_entropy(y_s, labels_s)
        # domain adversarial loss for D; features are detached so this loss
        # trains D alone and does not push gradients into the feature extractor
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())
        # get importance weights for the source samples from D's predictions
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0, with source samples re-weighted by w_s
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)
        # entropy loss on target predictions (y_t becomes probabilities here;
        # accuracy below is unaffected since softmax preserves the argmax)
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')
        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))

        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='IWAN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of source (and target) dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    # fix: help strings below previously read "loss(default: 0.1)" and
    # "loss(default: 3))" -- missing space / unbalanced parenthesis
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='the trade-off hyper-parameter for entropy loss (default: 0.1)')
    parser.add_argument('--trade-off', default=3, type=float,
                        help='the trade-off hyper-parameter for transfer loss (default: 3)')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='iwan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/iwan.sh
================================================
#!/usr/bin/env bash
# IWAN (Importance Weighted Adversarial Nets) for partial domain adaptation.
# Each command trains on the source domain (-s) and adapts to the partial
# target domain (-t); logs and checkpoints are written under --log.

# Office31: 6 transfer tasks; a lower lr (3e-4) is used on this dataset
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2A
# Office-Home: 12 tasks between Art (Ar), Clipart (Cl), Product (Pr), Real-World (Rw)
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Pr
# VisDA-2017: synthetic-to-real; report per-class accuracy
CUDA_VISIBLE_DEVICES=0 python iwan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --lr 0.0003 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/iwan/VisDA2017_S2R
# ImageNet-Caltech: large-scale partial DA in both directions (I->C and C->I)
CUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --seed 0 --log logs/iwan/I2C
CUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --seed 0 --bottleneck-dim 2048 --log logs/iwan/C2I
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/pada.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.classifier import Classifier
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.reweight.pada import AutomaticUpdateClassWeightModule
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run PADA (Partial Adversarial Domain Adaptation).

    Behavior depends on ``args.phase``:
      * ``'train'``    -- adversarial training with automatically updated class weights,
      * ``'test'``     -- evaluate the saved best checkpoint on the test set,
      * ``'analysis'`` -- visualize features with t-SNE and compute the A-distance.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Fix random seeds for reproducibility; deterministic CUDNN can be slower.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Wrap the loaders so training can draw a fixed number of batches per epoch.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        # Reuse the backbone's original classification head for ImageNet->Caltech.
        classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, head=backbone.copy_head()).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # Periodically re-estimates per-class weights from the classifier's target
    # predictions, down-weighting classes absent from the partial target domain.
    class_weight_module = AutomaticUpdateClassWeightModule(args.class_weight_update_steps, train_target_loader,
                                                           classifier, num_classes, device, args.temperature,
                                                           train_target_dataset.partial_classes_idx)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, class_weight_module,
              optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, class_weight_module: AutomaticUpdateClassWeightModule,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train the classifier for one epoch with PADA.

    Each iteration draws one mini-batch from the source and one from the target
    iterator, combines a class-weighted cross-entropy loss on source labels with
    a class-weighted domain-adversarial loss, and updates the model, the
    class-weight module and the learning-rate scheduler.

    Args:
        train_source_iter: endless iterator over labeled source batches.
        train_target_iter: endless iterator over target batches (labels used
            for logging target accuracy only, never for training).
        model: classifier returning (logits, features).
        domain_adv: adversarial loss wrapping the domain discriminator.
        class_weight_module: estimates per-class weights for partial DA.
        optimizer / lr_scheduler: stepped once per iteration.
        epoch: current epoch index (for logging).
        args: needs ``iters_per_epoch``, ``trade_off``, ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.1f')
    non_partial_classes_weights = AverageMeter('Non-partial Weight', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs, tgt_accs, partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)  # only used to log target accuracy below

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: forward source and target together in a single pass
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # source classification loss, re-weighted by the estimated class weights
        cls_loss = F.cross_entropy(y_s, labels_s, class_weight_module.get_class_weight_for_cross_entropy_loss())
        # per-sample weights for the adversarial loss, derived from source labels
        w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s)
        transfer_loss = domain_adv(f_s, f_t, w_s, w_t)
        # notify the class-weight module that one training step has passed
        class_weight_module.step()
        partial_classes_weight, non_partial_classes_weight = class_weight_module.get_partial_classes_weight()
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        partial_classes_weights.update(partial_classes_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PADA for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of source (and target) dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('-u', '--class-weight-update-steps', default=500, type=int,
                        help='Number of steps to update class weight once')
    parser.add_argument('--temperature', default=0.1, type=float,
                        help='temperature for softmax when calculating class weight')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    # fixed: help text said "(default: 4)" while the actual default is 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='pada',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/pada.sh
================================================
#!/usr/bin/env bash
# Launch scripts reproducing the PADA partial-domain-adaptation benchmarks.
# Each command trains on one source->target task; logs and checkpoints go under --log.

# Office31
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2A

# Office-Home (uses a lower softmax temperature for the class-weight estimate)
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Pr

# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python pada.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
    --epochs 20 --seed 0 -u 500 -i 500 --train-resizing cen.crop --trade-off 0.4 --per-class-eval --log logs/pada/VisDA2017_S2R

# ImageNet-Caltech
CUDA_VISIBLE_DEVICES=0 python pada.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
    --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --log logs/pada/I2C
CUDA_VISIBLE_DEVICES=0 python pada.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
    --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --bottleneck-dim 2048 --log logs/pada/C2I
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/requirements.txt
================================================
timm
================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/utils.py
================================================
import sys
import time
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
sys.path.append('../../..')
import tllib.vision.datasets.partial as datasets
from tllib.vision.datasets.partial import default_partial as partial
import tllib.vision.models as models
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
def get_model_names():
    """Return all supported backbone names: tllib.vision.models plus timm models."""
    local_names = [
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ]
    return sorted(local_names) + timm.list_models()
def get_model(model_name):
    """Create an ImageNet-pretrained feature-extractor backbone by name.

    Names present in ``tllib.vision.models`` are resolved there; anything else
    is loaded via timm. For timm models the classification head is removed so
    the module outputs features, and the backbone is patched with:
      - ``out_features``: dimensionality of the extracted features
      - ``copy_head``: access to the original classifier head

    Args:
        model_name: architecture name from :func:`get_model_names`.

    Returns:
        The backbone module.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models (timm)
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
            backbone.copy_head = backbone.get_classifier
        except AttributeError:
            # fixed: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit and unrelated errors.
            # Fallback for models exposing the classifier as `.head`
            # instead of via get_classifier()/reset_classifier().
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
            # NOTE(review): unlike the branch above, this callable takes the
            # module as an argument -- confirm call sites before relying on it
            backbone.copy_head = lambda x: x.head
    return backbone
def get_dataset_names():
    """Return the names of all datasets exposed by ``tllib.vision.datasets.partial``."""
    candidates = []
    for name in datasets.__dict__:
        if name.startswith("__"):
            continue
        if callable(datasets.__dict__[name]):
            candidates.append(name)
    return sorted(candidates)
def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None):
    """Build the datasets for a partial domain adaptation task.

    The source domain keeps the full label set while the target domain is
    restricted to the partial label set (via ``default_partial``).

    Returns:
        (train_source_dataset, train_target_dataset, val_dataset, test_dataset,
        num_classes, class_names)
    """
    if train_target_transform is None:
        train_target_transform = train_source_transform
    # resolve the dataset class in tllib.vision.datasets, then derive its partial variant
    full_dataset = datasets.__dict__[dataset_name]
    partial_variant = partial(full_dataset)
    train_source_dataset = full_dataset(root=root, task=source, download=True, transform=train_source_transform)
    train_target_dataset = partial_variant(root=root, task=target, download=True, transform=train_target_transform)
    val_dataset = partial_variant(root=root, task=target, download=True, transform=val_transform)
    # DomainNet ships a dedicated test split; all other datasets reuse the validation set
    if dataset_name == 'DomainNet':
        test_dataset = partial_variant(root=root, task=target, split='test', download=True, transform=val_transform)
    else:
        test_dataset = val_dataset
    class_names = train_source_dataset.classes
    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, len(class_names), class_names
def validate(val_loader, model, args, device) -> float:
    """Evaluate `model` on `val_loader` and return the average top-1 accuracy.

    Runs in eval mode without gradients, logs loss/accuracy every
    ``args.print_freq`` batches, and, when ``args.per_class_eval`` is set,
    additionally prints a per-class confusion-matrix summary.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    if args.per_class_eval:
        confmat = ConfusionMatrix(len(args.class_names))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1,))
            if confmat:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        if confmat:
            print(confmat.format(args.class_names))

    return top1.avg
def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):
    """Build the training-time augmentation pipeline.

    resizing mode:
        - default: resize the image to 256 and take a random resized crop of size 224;
        - cen.crop: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to 224;
        - res.|crop: resize the image to 256 and take a random crop of size 224;
        - res.sma|crop: resize the image keeping its aspect ratio such that the
          smaller side is 256, then take a random crop of size 224;
        - inc.crop: "inception crop" from (Szegedy et al., 2015);
        - cif.crop: resize the image to 224, zero-pad it by 28 on each side,
          then take a random crop of size 224.
    """
    if resizing == 'default':
        resize = T.Compose([ResizeImage(256), T.RandomResizedCrop(224)])
    elif resizing == 'cen.crop':
        resize = T.Compose([ResizeImage(256), T.CenterCrop(224)])
    elif resizing == 'res.':
        resize = T.Resize(224)
    elif resizing == 'res.|crop':
        resize = T.Compose([T.Resize((256, 256)), T.RandomCrop(224)])
    elif resizing == "res.sma|crop":
        resize = T.Compose([T.Resize(256), T.RandomCrop(224)])
    elif resizing == 'inc.crop':
        resize = T.RandomResizedCrop(224)
    elif resizing == 'cif.crop':
        resize = T.Compose([T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224)])
    else:
        raise NotImplementedError(resizing)

    pipeline = [resize]
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        pipeline.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    # ImageNet normalization statistics
    pipeline += [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return T.Compose(pipeline)
def get_val_transform(resizing='default'):
    """Build the evaluation-time preprocessing pipeline.

    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to 224;
        - res.|crop: resize the image such that the smaller side is of size 256
          and then take a central crop of size 224.
    """
    if resizing == 'default':
        resize = T.Compose([ResizeImage(256), T.CenterCrop(224)])
    elif resizing == 'res.':
        resize = T.Resize((224, 224))
    elif resizing == 'res.|crop':
        resize = T.Compose([T.Resize(256), T.CenterCrop(224)])
    else:
        raise NotImplementedError(resizing)

    # append tensor conversion and ImageNet normalization
    steps = [
        resize,
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return T.Compose(steps)
================================================
FILE: examples/domain_adaptation/re_identification/README.md
================================================
# Unsupervised Domain Adaptation for Person Re-Identification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html)
- [DukeMTMC](https://exposing.ai/duke_mtmc/)
- [MSMT17](https://arxiv.org/pdf/1711.08565.pdf)
## Supported Methods
Supported methods include:
- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- [Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (MMT, 2020 ICLR)](https://arxiv.org/abs/2001.01526)
- [Similarity Preserving Generative Adversarial Network (SPGAN, 2018 CVPR)](https://arxiv.org/pdf/1811.10551.pdf)
## Usage
The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to
train MMT on Market1501 -> DukeMTMC task, use the following script
```shell script
# Train MMT on Market1501 -> DukeMTMC task using ResNet 50.
# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`,
# or you are glad to download the datasets automatically from the Internet to this path
# MMT involves two training steps:
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \
--iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \
--iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data -t DukeMTMC -a reid_resnet50 \
--pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \
--pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \
--finetune --seed 0 --log logs/mmt/Market2Duke
```
### Experiment and Results
In our experiments, we adopt the modified ResNet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf). For a fair comparison,
we use standard cross entropy loss and triplet loss in all methods. For methods that utilize clustering algorithms,
we adopt kmeans or DBSCAN and report both results.
**Notations**
- ``Avg`` means the mAP (mean average precision) reported by `TLlib`.
- ``Baseline_Cluster`` represents the strong baseline in [MMT](https://arxiv.org/pdf/2001.01526.pdf).
### Cross dataset mAP on ResNet-50
| Methods | Avg | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke |
|--------------------------|------|-------------|-------------|-------------|-------------|-----------|-----------|
| Baseline | 27.1 | 32.4 | 31.4 | 8.2 | 36.7 | 11.0 | 43.1 |
| IBN | 30.0 | 35.2 | 36.5 | 11.3 | 38.7 | 14.1 | 44.3 |
| SPGAN | 30.7 | 34.4 | 35.4 | 14.1 | 40.2 | 16.1 | 43.8 |
| Baseline_Cluster(kmeans) | 45.1 | 52.8 | 59.5 | 19.0 | 62.6 | 20.3 | 56.2 |
| Baseline_Cluster(dbscan) | 54.9 | 62.5 | 73.5 | 25.2 | 77.9 | 25.3 | 65.0 |
| MMT(kmeans) | 55.4 | 63.7 | 72.5 | 26.2 | 75.8 | 28.0 | 66.1 |
| MMT(dbscan) | 60.0 | 68.2 | 80.0 | 28.2 | 82.5 | 31.2 | 70.0 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{IBN-Net,
author = {Xingang Pan and Ping Luo and Jianping Shi and Xiaoou Tang},
title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},
booktitle = {ECCV},
year = {2018}
}
@inproceedings{SPGAN,
title={Image-image domain adaptation with preserved self-similarity and domain-dissimilarity for person re-identification},
author={Deng, Weijian and Zheng, Liang and Ye, Qixiang and Kang, Guoliang and Yang, Yi and Jiao, Jianbin},
booktitle={CVPR},
year={2018}
}
@inproceedings{MMT,
title={Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification},
author={Yixiao Ge and Dapeng Chen and Hongsheng Li},
booktitle={ICLR},
year={2020},
}
```
================================================
FILE: examples/domain_adaptation/re_identification/baseline.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
from torch.nn import DataParallel
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
import utils
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss
from tllib.vision.models.reid.identifier import ReIdentifier
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.utils.scheduler import WarmupMultiStepLR
from tllib.utils.metric.reid import validate, visualize_ranked_results
from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train a source-only (baseline) re-ID model and evaluate on the target domain.

    Behavior depends on ``args.phase``:
      - 'train': train on the labeled source domain, validate on the source
        query/gallery, periodically test on the target domain and keep the
        checkpoint with the best source-validation mAP;
      - 'test': evaluate the best checkpoint on both domains;
      - 'analysis': plot t-SNE of features and visualize ranked retrieval results.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,
                                                random_horizontal_flip=True, random_color_jitter=False,
                                                random_gray_scale=False, random_erasing=False)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)

    # source dataset; the sampler guarantees several instances per identity in
    # each mini-batch (required by the triplet loss)
    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances)
    train_source_loader = DataLoader(
        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    # validation runs on the union of the source query and gallery sets
    val_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),
                                   root=source_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

    # target dataset (unlabeled usage: its labels are never consumed during training)
    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    test_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),
                                   root=target_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)

    # create model
    num_classes = source_dataset.num_train_pids
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)
    model = DataParallel(model)

    # define optimizer and lr scheduler
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,
                                     warmup_steps=args.warmup_steps)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        print("Test on source domain:")
        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return

    # define loss function
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)

    # start training
    best_val_mAP = 0.
    best_test_mAP = 0.
    for epoch in range(args.epochs):
        # print learning rate
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)

        # update learning rate
        lr_scheduler.step()

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # evaluate on validation set; checkpoint selection is driven by
            # source-domain mAP (no target labels used for model selection)
            print("Validation on source domain...")
            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,
                                  cmc_flag=True)

            # remember best mAP and save checkpoint
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if val_mAP > best_val_mAP:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_val_mAP = max(val_mAP, best_val_mAP)

            # evaluate on test set; best_test_mAP is an oracle number for reporting only
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            best_test_mAP = max(test_mAP, best_test_mAP)

    # evaluate on test set
    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    print("Test on target domain:")
    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                           cmc_flag=True, rerank=args.rerank)
    print("test mAP on target = {}".format(test_mAP))
    print("oracle mAP on target = {}".format(best_test_mAP))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model,
          criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, optimizer: Adam,
          epoch: int, args: argparse.Namespace):
    """Train the re-ID model for one epoch on labeled source data.

    Both losses (label-smoothed cross entropy and soft triplet) are computed on
    source batches only. Target batches are forwarded through the model but
    their outputs are discarded — presumably so the network still sees target
    statistics during training (TODO confirm intent).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, _, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output (target outputs y_t/f_t are intentionally unused)
        y_s, f_s = model(x_s)
        y_t, f_t = model(x_t)

        # cross entropy loss
        loss_ce = criterion_ce(y_s, labels_s)
        # triplet loss: the same source features serve as both embedding sets,
        # matching the SoftTripletLoss API usage
        loss_triplet = criterion_triplet(f_s, f_s, labels_s)
        loss = loss_ce + loss_triplet * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        losses_ce.update(loss_ce.item(), x_s.size(0))
        losses_triplet.update(loss_triplet.item(), x_s.size(0))
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="Baseline for Domain Adaptative ReID")
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', type=str, help='source domain')
    parser.add_argument('-t', '--target', type=str, help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    # training parameters
    parser.add_argument('--trade-off', type=float, default=1,
                        help='trade-off hyper parameter between cross entropy loss and triplet loss')
    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=16)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="initial learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')
    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],
                        help='milestones for the learning rate decay')
    parser.add_argument('--eval-step', type=int, default=40)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    # fixed: help text said "evaluation only", which described nothing about this flag
    parser.add_argument('--rerank', action='store_true',
                        help="whether to apply re-ranking when evaluating on query/gallery sets")
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/re_identification/baseline.sh
================================================
#!/usr/bin/env bash
# Launch scripts for the source-only re-ID baseline on all six benchmark
# transfer tasks. Each run trains on the source dataset and evaluates on the
# target dataset; logs and checkpoints go under --log.

# Market1501 -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke
# Duke -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market
# Market1501 -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT
# MSMT -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market
# Duke -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT
# MSMT -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
    --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke
================================================
FILE: examples/domain_adaptation/re_identification/baseline_cluster.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans, DBSCAN
import utils
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.vision.models.reid.identifier import ReIdentifier
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss
from tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results
from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Cluster-based domain-adaptive ReID baseline.

    Loads a source-pretrained model, then alternates per epoch between
    (1) clustering target-domain features to produce pseudo labels and
    (2) supervised training on those pseudo labels. Supports 'train',
    'test' and 'analysis' phases.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # fix all RNGs for reproducibility; deterministic cuDNN is slower
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,
                                                random_horizontal_flip=True, random_color_jitter=False,
                                                random_gray_scale=False, random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)
    # source dataset: query | gallery union is used only for evaluation
    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))
    val_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),
                                   root=source_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # target dataset: cluster_loader feeds feature extraction for pseudo-labeling
    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))
    cluster_loader = DataLoader(
        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    test_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),
                                   root=target_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # create model; the head is sized for args.num_clusters pseudo classes
    num_classes = args.num_clusters
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)
    model = DataParallel(model)
    # load pretrained (source-only) weights
    pretrained_model = torch.load(args.pretrained_model_path)
    utils.copy_state_dict(model, pretrained_model)
    # for test/analysis phases, restore the best checkpoint from this log dir
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return
    if args.phase == 'test':
        print("Test on Source domain:")
        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return
    # define loss function
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)
    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model, checkpoint['model'])
        args.start_epoch = checkpoint['epoch'] + 1
    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels;
        # dbscan additionally decides the number of classes per epoch
        if args.clustering_algorithm == 'kmeans':
            train_target_iter = run_kmeans(cluster_loader, model, target_dataset, train_transform, args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(cluster_loader, model, target_dataset, train_transform, args)
        # define cross entropy loss with current number of classes
        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
        # define optimizer (rebuilt each epoch because the head may have been replaced)
        optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,
                         weight_decay=args.weight_decay)
        # train for one epoch
        train(train_target_iter, model, optimizer, criterion_ce, criterion_triplet, epoch, args)
        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch)
            )
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)
    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
def run_kmeans(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,
               args: argparse.Namespace):
    """Pseudo-label the target training set with k-means.

    Extracts L2-normalized features for the whole target train split, runs
    k-means with ``args.num_clusters`` centers, warm-starts the classifier
    head from the normalized cluster centers, and returns an endless iterator
    over the pseudo-labeled training set.
    """
    print('Clustering into {} classes'.format(args.num_clusters))
    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)
    features = torch.stack(list(feature_dict.values())).cpu().numpy()
    kmeans = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(features)
    pseudo_labels = kmeans.labels_
    print('Clustering finished')
    # move the centers to the device and L2-normalize them so that they can
    # initialize the (cosine-style) classifier head
    centers = torch.from_numpy(kmeans.cluster_centers_).float().to(device)
    model.module.head.weight.data.copy_(F.normalize(centers, dim=1))
    # rebuild the training list with cluster assignments as identity labels
    target_train_set = [(fname, int(label), cid)
                        for (fname, _, cid), label in zip(target_dataset.train, pseudo_labels)]
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    return ForeverDataIterator(train_target_loader)
def run_dbscan(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,
               args: argparse.Namespace):
    """Pseudo-label the target training set with DBSCAN.

    Extracts L2-normalized target features, clusters a re-ranked pairwise
    distance matrix with DBSCAN (noise points, label ``-1``, are dropped),
    replaces the classifier head with one sized to the discovered number of
    clusters (initialized from the normalized cluster means), and returns an
    endless iterator over the pseudo-labeled set plus the cluster count.
    """
    # run dbscan clustering algorithm on the re-ranked distance matrix
    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)
    feature = torch.stack(list(feature_dict.values())).cpu()
    rerank_dist = utils.compute_rerank_dist(feature).numpy()
    print('Clustering with dbscan algorithm')
    dbscan = DBSCAN(eps=0.6, min_samples=4, metric='precomputed', n_jobs=-1)
    cluster_labels = dbscan.fit_predict(rerank_dist)
    print('Clustering finished')
    # generate training set with pseudo labels and calculate cluster centers
    target_train_set = []
    cluster_centers = {}
    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):
        if label == -1:
            # noise sample: exclude from the pseudo-labeled training set
            continue
        # cast to plain int for consistency with run_kmeans (avoids numpy
        # integer labels leaking into the dataset tuples)
        target_train_set.append((fname, int(label), cid))
        if label not in cluster_centers:
            cluster_centers[label] = []
        cluster_centers[label].append(feature[i])
    # per-cluster mean feature, ordered by cluster id
    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]
    cluster_centers = torch.stack(cluster_centers)
    # normalize cluster centers
    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)
    # reinitialize classifier head with the epoch-specific number of clusters
    features_dim = model.module.features_dim
    num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    model.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model.module.head.weight.data.copy_(cluster_centers)
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter, num_clusters
def train(train_target_iter: ForeverDataIterator, model, optimizer, criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, epoch: int, args: argparse.Namespace):
    """Train the model for one epoch on pseudo-labeled target-domain batches."""
    # meters displayed in the periodic progress line
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    # switch to train mode
    model.train()
    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, _, pseudo_labels, _ = next(train_target_iter)
        images = images.to(device)
        pseudo_labels = pseudo_labels.to(device)
        data_time.update(time.time() - tic)
        # forward pass: classification logits and embedding features
        logits, features = model(images)
        # label-smoothed cross entropy on pseudo labels
        loss_ce = criterion_ce(logits, pseudo_labels)
        # soft triplet loss on the embeddings (anchor set == positive set)
        loss_triplet = criterion_triplet(features, features, pseudo_labels)
        total_loss = loss_ce + loss_triplet * args.trade_off
        # bookkeeping
        n = images.size(0)
        cls_acc = accuracy(logits, pseudo_labels)[0]
        losses_ce.update(loss_ce.item(), n)
        losses_triplet.update(loss_triplet.item(), n)
        losses.update(total_loss.item(), n)
        cls_accs.update(cls_acc.item(), n)
        # gradient step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # collect all callable dataset classes exported by tllib.vision.datasets.reid
    # NOTE(review): dataset_names is built but not wired into --source/--target
    # choices in the visible code — confirm whether that was intended
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="Cluster Baseline for Domain Adaptative ReID")
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', type=str, help='source domain')
    parser.add_argument('-t', '--target', type=str, help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--num-clusters', type=int, default=500)
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    # training parameters
    parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'],
                        help='clustering algorithm to run, currently supported method: ["kmeans", "dbscan"]')
    parser.add_argument('--resume', type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--pretrained-model-path', type=str, help='path to pretrained (source-only) model')
    parser.add_argument('--trade-off', type=float, default=1,
                        help='trade-off hyper parameter between cross entropy loss and triplet loss')
    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--eval-step', type=int, default=1)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    parser.add_argument('--rerank', action='store_true', help="evaluation only")
    parser.add_argument("--log", type=str, default='baseline_cluster',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/re_identification/baseline_cluster.sh
================================================
#!/usr/bin/env bash
# Two-stage cluster-baseline pipeline for each domain pair:
#   step 1: source-only pretraining (single GPU)
#   step 2: adaptation with cluster-assigned pseudo labels (4 GPUs),
#           initialized from the best pretraining checkpoint.
pretrain() {
    # $1: source  $2: target  $3: log tag
    CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s "$1" -t "$2" -a reid_resnet50 \
        --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log "logs/baseline/$3"
}
cluster_train() {
    # $1: source  $2: target  $3: log tag; remaining args (e.g. --num-clusters)
    # are forwarded to baseline_cluster.py
    local src=$1 tgt=$2 tag=$3
    shift 3
    CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s "${src}" -t "${tgt}" -a reid_resnet50 \
        --pretrained-model-path "logs/baseline/${tag}/checkpoints/best.pth" \
        "$@" --finetune --seed 0 --log "logs/baseline_cluster/${tag}"
}
# Market1501 -> Duke
pretrain Market1501 DukeMTMC Market2Duke
cluster_train Market1501 DukeMTMC Market2Duke
# Duke -> Market1501
pretrain DukeMTMC Market1501 Duke2Market
cluster_train DukeMTMC Market1501 Duke2Market
# Market1501 -> MSMT (MSMT is larger, so more clusters)
pretrain Market1501 MSMT17 Market2MSMT
cluster_train Market1501 MSMT17 Market2MSMT --num-clusters 1000
# MSMT -> Market1501
pretrain MSMT17 Market1501 MSMT2Market
cluster_train MSMT17 Market1501 MSMT2Market
# Duke -> MSMT
pretrain DukeMTMC MSMT17 Duke2MSMT
cluster_train DukeMTMC MSMT17 Duke2MSMT --num-clusters 1000
# MSMT -> Duke
pretrain MSMT17 DukeMTMC MSMT2Duke
cluster_train MSMT17 DukeMTMC MSMT2Duke
================================================
FILE: examples/domain_adaptation/re_identification/ibn.sh
================================================
#!/usr/bin/env bash
# Source-only baseline with IBN-Net backbones: for every domain pair, train
# and evaluate with both the IBN-a and IBN-b variants of ResNet-50.
run_ibn_pair() {
    # $1: source  $2: target  $3: log tag — runs resnet50_ibn_a then resnet50_ibn_b
    local arch
    for arch in resnet50_ibn_a resnet50_ibn_b; do
        CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s "$1" -t "$2" -a "${arch}" \
            --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log "logs/ibn/$3"
    done
}
# Market1501 -> Duke
run_ibn_pair Market1501 DukeMTMC Market2Duke
# Duke -> Market1501
run_ibn_pair DukeMTMC Market1501 Duke2Market
# Market1501 -> MSMT
run_ibn_pair Market1501 MSMT17 Market2MSMT
# MSMT -> Market1501
run_ibn_pair MSMT17 Market1501 MSMT2Market
# Duke -> MSMT
run_ibn_pair DukeMTMC MSMT17 Duke2MSMT
# MSMT -> Duke
run_ibn_pair MSMT17 DukeMTMC MSMT2Duke
================================================
FILE: examples/domain_adaptation/re_identification/mmt.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import DataParallel
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans, DBSCAN
import utils
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.vision.models.reid.identifier import ReIdentifier
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss, CrossEntropyLoss
from tllib.self_training.mean_teacher import EMATeacher
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results
from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """MMT (Mutual Mean-Teaching) for domain-adaptive ReID.

    Trains two peer student networks (each with an EMA teacher) that
    supervise one another on pseudo labels produced per epoch by clustering
    target-domain features. Supports 'train', 'test' and 'analysis' phases;
    the best EMA teacher (by target mAP) is saved as the 'best' checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # fix all RNGs for reproducibility; deterministic cuDNN is slower
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,
                                                random_horizontal_flip=True, random_color_jitter=False,
                                                random_gray_scale=False, random_erasing=True)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    working_dir = osp.dirname(osp.abspath(__file__))
    source_root = osp.join(working_dir, args.source_root)
    target_root = osp.join(working_dir, args.target_root)
    # source dataset: query | gallery union is used only for evaluation
    source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower()))
    val_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),
                                   root=source_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # target dataset: cluster_loader feeds feature extraction for pseudo-labeling
    target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower()))
    cluster_loader = DataLoader(
        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    test_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),
                                   root=target_dataset.images_dir, transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # create the two peer models, each paired with an EMA teacher and
    # initialized from its own (source-only) pretrained checkpoint
    model_1, model_1_ema = create_model(args, args.pretrained_model_1_path)
    model_2, model_2_ema = create_model(args, args.pretrained_model_2_path)
    # for test/analysis phases, restore the best EMA teacher checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        utils.copy_state_dict(model_1_ema, checkpoint)
    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model_1_ema,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return
    if args.phase == 'test':
        print("Test on Source domain:")
        validate(val_loader, model_1_ema, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return
    # define loss functions: hard (label) and soft (peer-teacher) variants
    num_classes = args.num_clusters
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_ce_soft = CrossEntropyLoss().to(device)
    criterion_triplet = SoftTripletLoss(margin=0.0).to(device)
    criterion_triplet_soft = SoftTripletLoss(margin=None).to(device)
    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model_1, checkpoint['model_1'])
        utils.copy_state_dict(model_1_ema, checkpoint['model_1_ema'])
        utils.copy_state_dict(model_2, checkpoint['model_2'])
        utils.copy_state_dict(model_2_ema, checkpoint['model_2_ema'])
        args.start_epoch = checkpoint['epoch'] + 1
    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels;
        # dbscan additionally decides the number of classes per epoch
        if args.clustering_algorithm == 'kmeans':
            train_target_iter = run_kmeans(cluster_loader, model_1, model_2, model_1_ema, model_2_ema, target_dataset,
                                           train_transform, args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(cluster_loader, model_1, model_2, model_1_ema, model_2_ema,
                                                        target_dataset, train_transform, args)
        # define cross entropy loss with current number of classes
        criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
        # define optimizer over both students (rebuilt each epoch because the
        # heads may have been replaced by the clustering step)
        optimizer = Adam(model_1.module.get_parameters(base_lr=args.lr, rate=args.rate) + model_2.module.get_parameters(
            base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay)
        # train for one epoch
        train(train_target_iter, model_1, model_1_ema, model_2, model_2_ema, optimizer, criterion_ce, criterion_ce_soft,
              criterion_triplet, criterion_triplet_soft, epoch, args)
        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # save checkpoint and remember best mAP
            torch.save(
                {
                    'model_1': model_1.state_dict(),
                    'model_1_ema': model_1_ema.state_dict(),
                    'model_2': model_2.state_dict(),
                    'model_2_ema': model_2_ema.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch)
            )
            print("Test model_1 on target domain...")
            _, test_mAP_1 = validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery,
                                     device, cmc_flag=True, rerank=args.rerank)
            print("Test model_2 on target domain...")
            _, test_mAP_2 = validate(test_loader, model_2_ema, target_dataset.query, target_dataset.gallery,
                                     device, cmc_flag=True, rerank=args.rerank)
            # keep whichever EMA teacher currently wins (ties keep the old best)
            if test_mAP_1 > test_mAP_2 and test_mAP_1 > best_test_mAP:
                torch.save(model_1_ema.state_dict(), logger.get_checkpoint_path('best'))
                best_test_mAP = test_mAP_1
            if test_mAP_2 > test_mAP_1 and test_mAP_2 > best_test_mAP:
                torch.save(model_2_ema.state_dict(), logger.get_checkpoint_path('best'))
                best_test_mAP = test_mAP_2
    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()
def create_model(args: argparse.Namespace, pretrained_model_path: str):
    """Build one student (DataParallel-wrapped ReIdentifier) and its EMA teacher.

    The classifier head is sized for ``args.num_clusters`` pseudo classes and
    the weights are initialized from the given pretrained checkpoint.
    """
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    student = ReIdentifier(backbone, args.num_clusters, finetune=args.finetune, pool_layer=pool_layer).to(device)
    student = DataParallel(student)
    # initialize from the (source-only) pretrained weights
    utils.copy_state_dict(student, torch.load(pretrained_model_path))
    # the EMA teacher tracks the student with momentum args.alpha
    teacher = EMATeacher(student, args.alpha)
    return student, teacher
def run_kmeans(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,
               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):
    """Pseudo-label the target set with k-means over the averaged EMA-teacher features.

    Warm-starts all four classifier heads (both students and both EMA
    teachers) from the normalized cluster centers, and returns an endless
    iterator yielding two differently-augmented views of each sample.
    """
    # run kmeans clustering algorithm
    print('Clustering into {} classes'.format(args.num_clusters))
    # collect feature with different ema teachers
    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)
    feature_1 = torch.stack(list(feature_dict_1.values())).cpu().numpy()
    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)
    feature_2 = torch.stack(list(feature_dict_2.values())).cpu().numpy()
    # average feature_1, feature_2 to create final feature
    # NOTE(review): the averaged features are not re-normalized here, while
    # run_dbscan does re-normalize them — confirm this asymmetry is intended
    feature = (feature_1 + feature_2) / 2
    km = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(feature)
    cluster_labels = km.labels_
    cluster_centers = km.cluster_centers_
    print('Clustering finished')
    # normalize cluster centers and convert to pytorch tensor
    cluster_centers = torch.from_numpy(cluster_centers).float().to(device)
    cluster_centers = F.normalize(cluster_centers, dim=1)
    # reinitialize classifier head: all four heads share the same centers
    model_1.module.head.weight.data.copy_(cluster_centers)
    model_2.module.head.weight.data.copy_(cluster_centers)
    model_1_ema.module.head.weight.data.copy_(cluster_centers)
    model_2_ema.module.head.weight.data.copy_(cluster_centers)
    # generate training set with pseudo labels
    target_train_set = []
    for (fname, _, cid), label in zip(target_dataset.train, cluster_labels):
        target_train_set.append((fname, int(label), cid))
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    # MultipleApply yields two independently-augmented views per sample,
    # one for each peer network
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,
                                   transform=MultipleApply([train_transform, train_transform])),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter
def run_dbscan(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,
               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):
    """Pseudo-label the target set with DBSCAN over the averaged EMA-teacher features.

    Averages and re-normalizes the two EMA teachers' features, clusters a
    re-ranked distance matrix with DBSCAN (noise points, label ``-1``, are
    dropped), replaces all four classifier heads with ones sized to the
    discovered number of clusters (initialized from the normalized cluster
    means), and returns an endless two-view iterator plus the cluster count.
    """
    # run dbscan clustering algorithm
    # collect feature with different ema teachers
    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)
    feature_1 = torch.stack(list(feature_dict_1.values())).cpu()
    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)
    feature_2 = torch.stack(list(feature_dict_2.values())).cpu()
    # average feature_1, feature_2 and re-normalize to create final feature
    feature = (feature_1 + feature_2) / 2
    feature = F.normalize(feature, dim=1)
    rerank_dist = utils.compute_rerank_dist(feature).numpy()
    print('Clustering with dbscan algorithm')
    dbscan = DBSCAN(eps=0.7, min_samples=4, metric='precomputed', n_jobs=-1)
    cluster_labels = dbscan.fit_predict(rerank_dist)
    print('Clustering finished')
    # generate training set with pseudo labels and calculate cluster centers
    target_train_set = []
    cluster_centers = {}
    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):
        if label == -1:
            # noise sample: exclude from the pseudo-labeled training set
            continue
        # cast to plain int for consistency with run_kmeans (avoids numpy
        # integer labels leaking into the dataset tuples)
        target_train_set.append((fname, int(label), cid))
        if label not in cluster_centers:
            cluster_centers[label] = []
        cluster_centers[label].append(feature[i])
    # per-cluster mean feature, ordered by cluster id
    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]
    cluster_centers = torch.stack(cluster_centers)
    # normalize cluster centers
    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)
    # reinitialize classifier head of all four networks with the
    # epoch-specific number of clusters, sharing the same centers
    features_dim = model_1.module.features_dim
    num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    model_1.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_2.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_1_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_2_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_1.module.head.weight.data.copy_(cluster_centers)
    model_2.module.head.weight.data.copy_(cluster_centers)
    model_1_ema.module.head.weight.data.copy_(cluster_centers)
    model_2_ema.module.head.weight.data.copy_(cluster_centers)
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    # MultipleApply yields two independently-augmented views per sample,
    # one for each peer network
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,
                                   transform=MultipleApply([train_transform, train_transform])),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter, num_clusters
def train(train_target_iter: ForeverDataIterator, model_1: DataParallel, model_1_ema: EMATeacher, model_2: DataParallel,
          model_2_ema: EMATeacher, optimizer: Adam, criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_ce_soft: CrossEntropyLoss, criterion_triplet: SoftTripletLoss,
          criterion_triplet_soft: SoftTripletLoss, epoch: int, args: argparse.Namespace):
    """Run one epoch of MMT (Mutual Mean-Teaching) training with pseudo labels.

    Two peer students (`model_1`, `model_2`) are optimized jointly: each fits the
    hard pseudo labels (label-smoothed cross entropy + triplet loss) and is
    additionally supervised by the mean-teacher of the *other* student (soft cross
    entropy on logits, soft triplet loss on features). After every optimizer step
    both teachers are updated as exponential moving averages of their students.

    Args:
        train_target_iter: infinite iterator over the pseudo-labeled target set;
            each sample yields two augmented views of the same image
        model_1, model_2: the two student networks (DataParallel-wrapped)
        model_1_ema, model_2_ema: EMA teachers tracking model_1 / model_2
        optimizer: joint optimizer over both students' parameters
        criterion_ce: hard-label cross entropy with label smoothing
        criterion_ce_soft: soft cross entropy against the peer teacher's logits
        criterion_triplet: hard triplet loss on a student's own features
        criterion_triplet_soft: soft triplet loss against the peer teacher's features
        epoch: current epoch index (used for logging and the EMA warm-up schedule)
        args: command-line options (iters_per_epoch, trade-offs, alpha, print_freq, ...)
    """
    # train with pseudo labels
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    # statistics for model_1
    losses_ce_1 = AverageMeter('Model_1 CELoss', ':3.2f')
    losses_triplet_1 = AverageMeter('Model_1 TripletLoss', ':3.2f')
    cls_accs_1 = AverageMeter('Model_1 Cls Acc', ':3.1f')
    # statistics for model_2
    losses_ce_2 = AverageMeter('Model_2 CELoss', ':3.2f')
    losses_triplet_2 = AverageMeter('Model_2 TripletLoss', ':3.2f')
    cls_accs_2 = AverageMeter('Model_2 Cls Acc', ':3.1f')
    # statistics shared by both students
    losses_ce_soft = AverageMeter('Soft CELoss', ':3.2f')
    losses_triplet_soft = AverageMeter('Soft TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce_1, losses_triplet_1, cls_accs_1, losses_ce_2, losses_triplet_2, cls_accs_2,
         losses_ce_soft, losses_triplet_soft, losses],
        prefix="Epoch: [{}]".format(epoch))
    # switch to train mode
    model_1.train()
    model_2.train()
    model_1_ema.train()
    model_2_ema.train()
    end = time.time()
    for i in range(args.iters_per_epoch):
        # below we ignore subscript `t` and use `x_1`, `x_2` to denote different augmented versions of origin samples
        # `x_t` from target domain
        (x_1, x_2), _, labels, _ = next(train_target_iter)
        x_1 = x_1.to(device)
        x_2 = x_2.to(device)
        labels = labels.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # compute output; each student sees a different augmented view
        y_1, f_1 = model_1(x_1)
        y_2, f_2 = model_2(x_2)
        # compute output by ema-teacher (teacher i sees the same view as student i)
        y_1_teacher, f_1_teacher = model_1_ema(x_1)
        y_2_teacher, f_2_teacher = model_2_ema(x_2)
        # cross entropy loss on hard pseudo labels
        loss_ce_1 = criterion_ce(y_1, labels)
        loss_ce_2 = criterion_ce(y_2, labels)
        # triplet loss (each student's features are compared against themselves within the batch)
        loss_triplet_1 = criterion_triplet(f_1, f_1, labels)
        loss_triplet_2 = criterion_triplet(f_2, f_2, labels)
        # soft cross entropy loss: each student is supervised by the *other* student's teacher
        loss_ce_soft = criterion_ce_soft(y_1, y_2_teacher) + \
                       criterion_ce_soft(y_2, y_1_teacher)
        # soft triplet loss, with the same cross pairing on features
        loss_triplet_soft = criterion_triplet_soft(f_1, f_2_teacher, labels) + \
                            criterion_triplet_soft(f_2, f_1_teacher, labels)
        # final objective: convex combination of hard and soft objectives
        loss = (loss_ce_1 + loss_ce_2) * (1 - args.trade_off_ce_soft) + \
               (loss_triplet_1 + loss_triplet_2) * (1 - args.trade_off_triplet_soft) + \
               loss_ce_soft * args.trade_off_ce_soft + \
               loss_triplet_soft * args.trade_off_triplet_soft
        # update statistics
        batch_size = args.batch_size
        cls_acc_1 = accuracy(y_1, labels)[0]
        cls_acc_2 = accuracy(y_2, labels)[0]
        # model 1
        losses_ce_1.update(loss_ce_1.item(), batch_size)
        losses_triplet_1.update(loss_triplet_1.item(), batch_size)
        cls_accs_1.update(cls_acc_1.item(), batch_size)
        # model 2
        losses_ce_2.update(loss_ce_2.item(), batch_size)
        losses_triplet_2.update(loss_triplet_2.item(), batch_size)
        cls_accs_2.update(cls_acc_2.item(), batch_size)
        losses_ce_soft.update(loss_ce_soft.item(), batch_size)
        losses_triplet_soft.update(loss_triplet_soft.item(), batch_size)
        losses.update(loss.item(), batch_size)
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # update teacher: ramp the EMA momentum up from 0 towards args.alpha so that
        # early teachers follow the students quickly, then stabilize
        global_step = epoch * args.iters_per_epoch + i + 1
        model_1_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))
        model_2_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))
        model_1_ema.update()
        model_2_ema.update()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # dataset factories exposed by tllib.vision.datasets.reid (e.g. Market1501, DukeMTMC, MSMT17)
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="MMT for Domain Adaptative ReID")
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    # fix: `dataset_names` was computed but never used; wiring it into `choices`
    # makes an unknown dataset name fail fast at parse time instead of crashing later
    parser.add_argument('-s', '--source', type=str, help='source domain', choices=dataset_names)
    parser.add_argument('-t', '--target', type=str, help='target domain', choices=dataset_names)
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--num-clusters', type=int, default=500)
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--alpha', type=float, default=0.999, help='ema alpha')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    # training parameters
    parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'],
                        help='clustering algorithm to run, currently supported method: ["kmeans", "dbscan"]')
    parser.add_argument('--resume', type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--pretrained-model-1-path', type=str, help='path to pretrained (source-only) model_1')
    parser.add_argument('--pretrained-model-2-path', type=str, help='path to pretrained (source-only) model_2')
    parser.add_argument('--trade-off-ce-soft', type=float, default=0.5,
                        help='the trade off hyper parameter between cross entropy loss and soft cross entropy loss')
    parser.add_argument('--trade-off-triplet-soft', type=float, default=0.8,
                        help='the trade off hyper parameter between triplet loss and soft triplet loss')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--eval-step', type=int, default=1)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    # fix: help text previously read "evaluation only" (copied from another flag);
    # this switch actually toggles re-ranking when computing evaluation metrics
    parser.add_argument('--rerank', action='store_true', help="whether to use re-ranking during evaluation")
    parser.add_argument("--log", type=str, default='mmt',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/re_identification/mmt.sh
================================================
#!/usr/bin/env bash
# MMT (Mutual Mean-Teaching) for unsupervised domain-adaptive person re-identification.
# Each transfer task below runs in two steps:
#   step1: pretrain two source-only baselines with different random seeds
#   step2: run MMT, initializing its two peer networks from the two baseline checkpoints
# MSMT17 contains many more identities, hence --num-clusters 1000 when it is the target.

# Market1501 -> Duke
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/Market2Duke

# Duke -> Market1501
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MarketSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MarketSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Duke2MarketSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Duke2MarketSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/Duke2Market

# Market1501 -> MSMT
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMTSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2MSMTSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Market2MSMTSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Market2MSMTSeed1/checkpoints/best.pth \
  --num-clusters 1000 --finetune --seed 0 --log logs/mmt/Market2MSMT

# MSMT -> Market1501
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2MarketSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2MarketSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/MSMT2MarketSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/MSMT2MarketSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/MSMT2Market

# Duke -> MSMT
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMTSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MSMTSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Duke2MSMTSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Duke2MSMTSeed1/checkpoints/best.pth \
  --num-clusters 1000 --finetune --seed 0 --log logs/mmt/Duke2MSMT

# MSMT -> Duke
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2DukeSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2DukeSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/MSMT2DukeSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/MSMT2DukeSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/MSMT2Duke
================================================
FILE: examples/domain_adaptation/re_identification/requirements.txt
================================================
timm
opencv-python
================================================
FILE: examples/domain_adaptation/re_identification/spgan.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools
import os.path as osp
from PIL import Image
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T
sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
import tllib.translation.spgan as spgan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Entry point for SPGAN training / dataset translation.

    Trains CycleGAN-style generators and discriminators between the source and
    target ReID domains, together with a siamese network whose contrastive loss
    (the "similarity preserving" constraint of SPGAN) encourages a translated
    image to keep the identity of the original person. When ``args.phase ==
    'test'`` the function only translates the source dataset with the (resumed)
    ``netG_S2T`` generator and returns. If ``args.translated_root`` is set, the
    source dataset is also translated once training finishes.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code: resize larger than the crop, then random-crop + flip,
    # and normalize to [-1, 1] as GAN generators expect
    train_transform = T.Compose([
        T.Resize(args.load_size, Image.BICUBIC),
        T.RandomCrop(args.input_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    working_dir = osp.dirname(osp.abspath(__file__))
    root = osp.join(working_dir, args.root)
    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))
    train_source_loader = DataLoader(
        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # define networks (generators, discriminators and siamese network)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    siamese_net = spgan.SiameseNetwork(nsf=args.nsf).to(device)
    # create image buffer to store previously generated images (stabilizes D training)
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)
    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_siamese = Adam(siamese_net.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    # keep lr constant for `epochs` epochs, then decay linearly to zero over `epochs_decay` epochs
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)
    lr_scheduler_siamese = LambdaLR(optimizer_siamese, lr_lambda=lr_decay_function)
    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        siamese_net.load_state_dict(checkpoint['siamese_net'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        optimizer_siamese.load_state_dict(checkpoint['optimizer_siamese'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        lr_scheduler_siamese.load_state_dict(checkpoint['lr_scheduler_siamese'])
        args.start_epoch = checkpoint['epoch'] + 1
    if args.phase == 'test':
        # translation-only mode: write the translated source dataset and exit
        transform = T.Compose([
            T.Resize(args.test_input_size, Image.BICUBIC),
            cyclegan.transform.Translation(netG_S2T, device)
        ])
        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))
        return
    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_contrastive = spgan.ContrastiveLoss(margin=args.margin)
    # define visualization function
    tensor_to_image = T.Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        T.ToPILImage()
    ])
    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))
    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        # fix: `get_lr()` is deprecated when called outside `LRScheduler.step()` and
        # emits a warning; `get_last_lr()` returns the most recently computed rates
        print(lr_scheduler_G.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, siamese_net,
              criterion_gan, criterion_cycle, criterion_identity, criterion_contrastive,
              optimizer_G, optimizer_D, optimizer_siamese,
              fake_S_pool, fake_T_pool, epoch, visualize, args)
        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        lr_scheduler_siamese.step()
        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'siamese_net': siamese_net.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_siamese': optimizer_siamese.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'lr_scheduler_siamese': lr_scheduler_siamese.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
    if args.translated_root is not None:
        # translate the source dataset with the final generator
        transform = T.Compose([
            T.Resize(args.test_input_size, Image.BICUBIC),
            cyclegan.transform.Translation(netG_S2T, device)
        ])
        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          netG_S2T, netG_T2S, netD_S, netD_T, siamese_net: spgan.SiameseNetwork,
          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss,
          criterion_cycle: nn.L1Loss, criterion_identity: nn.L1Loss,
          criterion_contrastive: spgan.ContrastiveLoss,
          optimizer_G: Adam, optimizer_D: Adam, optimizer_siamese: Adam,
          fake_S_pool: ImagePool, fake_T_pool: ImagePool, epoch: int, visualize, args: argparse.Namespace):
    """Train SPGAN for one epoch.

    Each iteration performs up to three updates:
      1. generators (every second iteration): adversarial + cycle + identity losses,
         plus (after epoch 1) a contrastive loss that keeps translated images close
         to their originals in the siamese embedding space;
      2. siamese network (after epoch 0): contrastive loss on real/translated pairs
         with the generator outputs detached;
      3. discriminators: least-squares GAN loss using an image pool of past fakes.

    Args:
        train_source_iter / train_target_iter: infinite iterators over the two domains
        netG_S2T, netG_T2S: generators between the source (S) and target (T) domains
        netD_S, netD_T: discriminators for each domain
        siamese_net: embedding network for the similarity-preserving constraint
        visualize: callback ``visualize(image, name)`` that saves a debug image
        args: command-line options (iters_per_epoch, trade-offs, print_freq, ...)
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')
    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T,
         losses_contrastive_G, losses_contrastive_siamese],
        prefix="Epoch: [{}]".format(epoch))
    end = time.time()
    for i in range(args.iters_per_epoch):
        real_S, _, _, _ = next(train_source_iter)
        real_T, _, _, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)
        # ===============================================
        # train the generators (every two iterations)
        # ===============================================
        if i % 2 == 0:
            # save memory: no gradients are needed through D or the siamese net here
            set_requires_grad(netD_S, False)
            set_requires_grad(netD_T, False)
            set_requires_grad(siamese_net, False)
            # GAN loss D_T(G_S2T(S))
            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
            # GAN loss D_S(G_T2S(B))
            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
            # Cycle loss || G_T2S(G_S2T(S)) - S||
            loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
            # Cycle loss || G_S2T(G_T2S(T)) - T||
            loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
            # Identity loss
            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
            identity_T = netG_S2T(real_T)
            loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
            identity_S = netG_T2S(real_S)
            loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
            # siamese network output
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T)
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S)
            # positive pair: an image and its translation should stay close
            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                   criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair: translations should differ from real images of the other domain
            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \
                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \
                                   criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_G = (loss_contrastive_p_G + 0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive
            # combined loss and calculate gradients; the contrastive term only joins
            # once the siamese network has been trained for a while (epoch > 1)
            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
            if epoch > 1:
                loss_G += loss_contrastive_G
            netG_S2T.zero_grad()
            netG_T2S.zero_grad()
            loss_G.backward()
            optimizer_G.step()
            # update corresponding statistics
            losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
            losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
            losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
            losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
            losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
            losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
            if epoch > 1:
                # fix: call .item() so the meter stores a python float; previously the
                # raw (CUDA) tensor was stored, keeping device memory alive across the
                # epoch and printing tensors instead of floats
                losses_contrastive_G.update(loss_contrastive_G.item(), real_S.size(0))
        # ===============================================
        # train the siamese network (when epoch > 0)
        # ===============================================
        if epoch > 0:
            set_requires_grad(siamese_net, True)
            # siamese network output; generator outputs are detached so only the
            # siamese network receives gradients
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T.detach())
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S.detach())
            # positive pair
            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                         criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_siamese = criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_siamese = (loss_contrastive_p_siamese + 2 * loss_contrastive_n_siamese) / 3
            # update siamese network
            siamese_net.zero_grad()
            loss_contrastive_siamese.backward()
            optimizer_siamese.step()
            # update corresponding statistics
            # fix: .item() for the same reason as losses_contrastive_G above
            losses_contrastive_siamese.update(loss_contrastive_siamese.item(), real_S.size(0))
        # ===============================================
        # train the discriminators
        # ===============================================
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        # Calculate GAN loss for discriminator D_S
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        # update discriminators
        netD_S.zero_grad()
        netD_T.zero_grad()
        loss_D_S.backward()
        loss_D_T.backward()
        optimizer_D.step()
        # update corresponding statistics
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
            # NOTE(review): identity_S/identity_T are only (re)assigned on even
            # iterations; this is safe because i=0 is even, but the visualized
            # identity images can be one iteration stale when print_freq is odd
            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T", "rec_S",
                                     "rec_T", "identity_S", "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))
if __name__ == '__main__':
    # dataset factories exposed by tllib.vision.datasets.reid (e.g. Market1501, DukeMTMC, MSMT17)
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='SPGAN for Domain Adaptative ReID')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: `dataset_names` was computed but never used; wiring it into `choices`
    # makes an unknown dataset name fail fast at parse time instead of crashing later
    parser.add_argument('-s', '--source', type=str, help='source domain', choices=dataset_names)
    parser.add_argument('-t', '--target', type=str, help='target domain', choices=dataset_names)
    parser.add_argument('--load-size', nargs='+', type=int, default=(286, 144), help='loading image size')
    parser.add_argument('--input-size', nargs='+', type=int, default=(256, 128),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--nsf', type=int, default=64, help='# of sianet filters int the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='resnet_9',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    # training parameters
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')
    parser.add_argument('--trade-off-contrastive', type=float, default=2.0, help='trade off for contrastive loss')
    parser.add_argument('--margin', type=float, default=2,
                        help='margin for contrastive loss')
    parser.add_argument('-b', '--batch-size', default=8, type=int,
                        metavar='N', help='mini-batch size (default: 8)')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=15, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=15,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    parser.add_argument('-p', '--print-freq', default=500, type=int,
                        metavar='N', help='print frequency (default: 500)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='spgan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(256, 128),
                        help='the input image size during testing')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/re_identification/spgan.sh
================================================
# SPGAN for unsupervised domain-adaptive person re-identification.
# Each transfer task below runs in two steps:
#   step1: train SPGAN and write the translated (source -> target style) dataset
#          to --translated-root
#   step2: train a source-only baseline on the translated source dataset

# Market1501 -> Duke
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t DukeMTMC \
  --log logs/spgan/Market2Duke --translated-root data/spganM2D --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2D data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2Duke

# Duke -> Market1501
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t Market1501 \
  --log logs/spgan/Duke2Market --translated-root data/spganD2M --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2M data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2Market

# Market1501 -> MSMT17
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t MSMT17 \
  --log logs/spgan/Market2MSMT --translated-root data/spganM2S --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2S data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2MSMT

# MSMT -> Market1501
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t Market1501 \
  --log logs/spgan/MSMT2Market --translated-root data/spganS2M --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2M data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Market

# Duke -> MSMT
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t MSMT17 \
  --log logs/spgan/Duke2MSMT --translated-root data/spganD2S --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2S data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2MSMT

# MSMT -> Duke
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t DukeMTMC \
  --log logs/spgan/MSMT2Duke --translated-root data/spganS2D --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2D data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Duke
================================================
FILE: examples/domain_adaptation/re_identification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import sys
import timm
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torchvision.transforms as T
sys.path.append('../../..')
from tllib.utils.metric.reid import extract_reid_feature
from tllib.utils.analysis import tsne
from tllib.vision.transforms import RandomErasing
import tllib.vision.models.reid as models
import tllib.normalization.ibn as ibn_models
def copy_state_dict(model, state_dict, strip=None):
    """Load ``state_dict`` into ``model``, skipping incompatible entries.

    Because a classification loss is used, the number of output classes
    (identities) differs between datasets, so shape-mismatched parameters
    (e.g. the final ``fc`` layer) are skipped rather than raising.

    Args:
        model: target module, updated in place.
        state_dict: mapping of parameter name -> tensor to copy from.
        strip: optional prefix (e.g. ``'module.'``) removed from source
            parameter names before matching.

    Returns:
        The same ``model`` instance.
    """
    own_state = model.state_dict()
    matched = set()
    for key, value in state_dict.items():
        if strip is not None and key.startswith(strip):
            key = key[len(strip):]
        if key not in own_state:
            # no corresponding parameter in the target model
            continue
        if isinstance(value, Parameter):
            value = value.data
        if value.size() != own_state[key].size():
            print('mismatch:', key, value.size(), own_state[key].size())
            continue
        own_state[key].copy_(value)
        matched.add(key)
    unmatched = set(own_state.keys()) - matched
    if len(unmatched) > 0:
        print("missing keys in state_dict:", unmatched)
    return model
def get_model_names():
    """Return the names of all supported backbones: re-ID models from
    ``tllib.vision.models.reid``, IBN variants from
    ``tllib.normalization.ibn``, and every model registered in ``timm``.
    """
    def list_factories(module):
        # public lowercase callables exposed by the module
        return sorted(entry for entry in module.__dict__
                      if entry.islower() and not entry.startswith("__")
                      and callable(module.__dict__[entry]))

    return list_factories(models) + list_factories(ibn_models) + timm.list_models()
def get_model(model_name):
    """Create an ImageNet-pretrained backbone by name.

    ``model_name`` is resolved, in order, against ``tllib.vision.models.reid``,
    ``tllib.normalization.ibn`` and finally ``timm``. For timm models the
    classification head is removed so the network outputs features, and the
    feature dimension is recorded as ``backbone.out_features``.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    elif model_name in ibn_models.__dict__:
        # load models (with ibn) from tllib.normalization.ibn
        backbone = ibn_models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except AttributeError:
            # Fix: the original bare `except:` also swallowed KeyboardInterrupt
            # and SystemExit. Only an AttributeError (missing classifier API or
            # a classifier without `in_features`) should trigger this fallback,
            # which strips the `head` module directly.
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone
def get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False,
                        random_gray_scale=False, random_erasing=False):
    """Build the training-time augmentation pipeline.

    resizing mode:
        - default: resize the image to (height, width), zero-pad it by 10 on
          each side, then take a random crop of (height, width)
        - res: resize the image to (height, width)
    """
    if resizing == 'default':
        resize_step = T.Compose([
            T.Resize((height, width), interpolation=3),
            T.Pad(10),
            T.RandomCrop((height, width))
        ])
    elif resizing == 'res':
        resize_step = T.Resize((height, width), interpolation=3)
    else:
        raise NotImplementedError(resizing)
    pipeline = [resize_step]
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        pipeline.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))
    if random_gray_scale:
        pipeline.append(T.RandomGrayscale())
    # ImageNet normalization statistics
    pipeline += [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    if random_erasing:
        # applied after normalization, erased pixels filled with the mean
        pipeline.append(RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]))
    return T.Compose(pipeline)
def get_val_transform(height, width):
    """Resize + ImageNet-normalization pipeline used at evaluation time
    (no augmentation)."""
    steps = [
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
    return T.Compose(steps)
def visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):
    """Visualize features from both domains with t-SNE.

    Since each domain can contain a very large number of samples, at most
    ``n_data_points_per_domain`` randomly selected samples per domain are
    plotted.
    """
    def sample_features(loader):
        # extract L2-normalized features, shuffle, keep a random subset
        feature_dict = extract_reid_feature(loader, model, device, normalize=True)
        features = torch.stack(list(feature_dict.values())).cpu()
        features = features[torch.randperm(len(features))]
        return features[:n_data_points_per_domain]

    source_feature = sample_features(source_loader)
    target_feature = sample_features(target_loader)
    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue', target_color='darkorange')
    print('T-SNE process is done, figure is saved to {}'.format(filename))
def k_reciprocal_neigh(initial_rank, i, k1):
    """Return the k-reciprocal neighbors of sample ``i``.

    Samples ``i`` and ``j`` are k-reciprocal neighbors iff each appears among
    the ``k1 + 1`` nearest samples of the other (the ``+ 1`` accounts for each
    sample ranking itself first).
    """
    forward_neighbors = initial_rank[i, :k1 + 1]
    backward_ranks = initial_rank[forward_neighbors, :k1 + 1]
    reciprocal_rows = torch.nonzero(backward_ranks == i)[:, 0]
    return forward_neighbors[reciprocal_rows]
def compute_rerank_dist(target_features, k1=30, k2=6):
    """Compute the jaccard distance used for re-ranking, following
    `Re-ranking Person Re-identification with k-reciprocal Encoding
    (CVPR 2017) <https://arxiv.org/abs/1701.08398>`_.

    Args:
        target_features (tensor): n x d feature matrix.
        k1 (int): neighborhood size for the k-reciprocal sets.
        k2 (int): neighborhood size for local query expansion (1 disables it).

    Returns:
        tensor: n x n jaccard distance matrix.
    """
    n = target_features.size(0)
    # pairwise (scaled) squared euclidean distances; NOTE(review): the usual
    # ||a||^2 + ||b||^2 - 2ab is here built from a single *2 term plus
    # transpose/normalization below — presumably equivalent after the
    # column-wise max normalization; confirm against the reference impl
    original_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True) * 2
    original_dist = original_dist.expand(n, n) - 2 * torch.mm(target_features, target_features.t())
    original_dist /= original_dist.max(0)[0]
    original_dist = original_dist.t()
    # initial_rank[i] lists all samples ordered by increasing distance to i
    initial_rank = torch.argsort(original_dist, dim=-1)
    all_num = gallery_num = original_dist.size(0)
    del target_features
    # k-reciprocal neighbor sets at radius k1 and k1/2
    nn_k1 = []
    nn_k1_half = []
    for i in range(all_num):
        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))
    # V[i] is a sparse weight vector encoding i's (expanded) reciprocal set
    V = torch.zeros(all_num, all_num)
    for i in range(all_num):
        k_reciprocal_index = nn_k1[i]
        k_reciprocal_expansion_index = k_reciprocal_index
        # expand the set with candidates whose half-radius reciprocal set
        # overlaps i's set by more than 2/3
        for candidate in k_reciprocal_index:
            candidate_k_reciprocal_index = nn_k1_half[candidate]
            if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index)):
                k_reciprocal_expansion_index = torch.cat((k_reciprocal_expansion_index, candidate_k_reciprocal_index))
        k_reciprocal_expansion_index = torch.unique(k_reciprocal_expansion_index)
        # gaussian-kernel weights over distances, normalized to sum to 1
        weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight)
    if k2 != 1:
        # local query expansion: average V over each sample's k2 nearest rows
        k2_rank = initial_rank[:, :k2].clone().view(-1)
        V_qe = V[k2_rank]
        V_qe = V_qe.view(initial_rank.size(0), k2, -1).sum(1)
        V_qe /= k2
        V = V_qe
        del V_qe
    del initial_rank
    # invIndex[j]: row indices whose V has a nonzero entry in column j
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(torch.nonzero(V[:, i])[:, 0])
    jaccard_dist = torch.zeros_like(original_dist)
    for i in range(all_num):
        # accumulate elementwise min(V[i], V[j]) per gallery sample j,
        # touching only columns where V[i] is nonzero
        temp_min = torch.zeros(1, gallery_num)
        indNonZero = torch.nonzero(V[i, :])[:, 0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + \
                                        torch.min(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        # jaccard distance: 1 - |min-sum| / |max-sum|, with max = 2 - min
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
    del invIndex
    del V
    # clamp tiny negative values introduced by floating point error
    pos_bool = (jaccard_dist < 0)
    jaccard_dist[pos_bool] = 0.0
    return jaccard_dist
================================================
FILE: examples/domain_adaptation/semantic_segmentation/README.md
================================================
# Unsupervised Domain Adaptation for Semantic Segmentation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
## Dataset
You need to prepare following datasets manually if you want to use them:
- [Cityscapes](https://www.cityscapes-dataset.com/)
- [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/)
- [Synthia](https://synthia-dataset.net/)
#### Cityscapes, Foggy Cityscapes
- Download Cityscapes and Foggy Cityscapes dataset from the [link](https://www.cityscapes-dataset.com/downloads/). Particularly, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes.
- Unzip them under the directory like
```
data/Cityscapes
├── gtFine
├── leftImg8bit
│ ├── train
│ ├── val
│ └── test
├── leftImg8bit_foggy
│ ├── train
│ ├── val
│ └── test
└── ...
```
#### GTA-5
You need to download GTA5 manually from [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/).
Ensure that the following directories exist before you use this dataset.
```
data/GTA5
├── images
├── labels
└── ...
```
#### Synthia
You need to download Synthia manually from [Synthia](https://synthia-dataset.net/).
Ensure that the following directories exist before you use this dataset.
```
data/synthia
├── RGB
├── synthia_mapped_to_cityscapes
└── ...
```
## Supported Methods
Supported methods include:
- [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf)
- [CyCADA: Cycle-Consistent Adversarial Domain Adaptation](https://arxiv.org/abs/1711.03213)
- [Adversarial Entropy Minimization (ADVENT)](https://arxiv.org/abs/1811.12833)
- [Fourier Domain Adaptation (FDA)](https://arxiv.org/abs/2004.05498)
## Experiment and Results
**Notations**
- ``Origin`` means the accuracy reported by the original paper.
- ``mIoU`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.
### GTA5->Cityscapes mIoU on deeplabv2 (ResNet-101)
| GTA5 | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrain | sky | person | rider | car | truck | bus | train | motorbike | bicycle |
|-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------|
| ERM | 27.1 | 37.3 | 66.5 | 17.4 | 73.3 | 13.4 | 21.5 | 22.8 | 30.1 | 17.1 | 82.2 | 7.1 | 73.6 | 57.4 | 28.4 | 78.6 | 36.1 | 13.4 | 1.5 | 31.9 | 36.2 |
| AdvEnt | 43.8 | 43.8 | 89.3 | 33.9 | 80.3 | 24.0 | 25.2 | 27.8 | 36.7 | 18.2 | 84.3 | 33.9 | 81.3 | 59.8 | 28.4 | 84.3 | 34.1 | 44.4 | 0.1 | 33.2 | 12.9 |
| FDA | 44.6 | 45.6 | 85.5 | 31.7 | 81.8 | 27.1 | 24.9 | 28.9 | 38.1 | 23.2 | 83.7 | 40.3 | 80.6 | 60.5 | 30.3 | 79.1 | 32.8 | 45.1 | 5.0 | 32.4 | 35.2 |
| Cycada | 42.7 | 47.4 | 87.3 | 35.7 | 83.7 | 31.3 | 24.0 | 32.2 | 35.8 | 30.3 | 82.7 | 32.0 | 85.7 | 60.8 | 31.5 | 85.6 | 39.8 | 43.3 | 5.4 | 29.5 | 44.6 |
| CycleGAN | | 47.0 | 88.4 | 41.9 | 83.6 | 34.4 | 23.9 | 32.9 | 35.5 | 26.0 | 83.1 | 36.8 | 82.3 | 59.9 | 27.0 | 83.4 | 31.6 | 42.3 | 11.0 | 28.2 | 40.5 |
| Oracle | 65.1 | 70.5 | 97.4 | 79.7 | 90.1 | 53.0 | 50.0 | 48.0 | 55.5 | 67.2 | 90.2 | 60.0 | 93.0 | 72.7 | 55.2 | 92.7 | 76.5 | 78.5 | 56.0 | 54.6 | 68.8 |
### Synthia->Cityscapes mIoU on deeplabv2 (ResNet-101)
| Synthia | Origin | mIoU | road | sidewalk | building | traffic light | traffic sign | vegetation | sky | person | rider | car | bus | motorbike | bicycle |
|-------------|--------|------|------|----------|----------|---------------|--------------|------------|------|--------|-------|------|------|-----------|---------|
| ERM | 22.1 | 41.5 | 59.6 | 21.1 | 77.4 | 7.7 | 17.6 | 78.0 | 84.5 | 53.2 | 16.9 | 65.9 | 24.9 | 8.5 | 24.8 |
| AdvEnt | 47.6 | 47.9 | 88.3 | 44.9 | 80.5 | 4.5 | 9.1 | 81.3 | 86.2 | 52.9 | 21.0 | 82.0 | 30.3 | 11.9 | 30.2 |
| FDA | - | 43.9 | 62.5 | 23.7 | 78.5 | 9.4 | 15.7 | 78.3 | 81.1 | 52.3 | 18.7 | 79.8 | 32.5 | 8.7 | 29.6 |
| Oracle | 71.7 | 76.6 | 97.4 | 79.7 | 90.1 | 55.5 | 67.2 | 90.2 | 93.0 | 72.7 | 55.2 | 92.7 | 78.5 | 54.6 | 68.8 |
### Cityscapes->Foggy Cityscapes mIoU on deeplabv2 (ResNet-101)
| Foggy | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrain | sky | person | rider | car | truck | bus | train | motorbike | bicycle |
|-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------|
| ERM | | 51.2 | 95.3 | 70.2 | 64.1 | 31.9 | 35.2 | 30.7 | 33.3 | 51.1 | 42.3 | 44.0 | 32.1 | 64.4 | 47.0 | 86.0 | 64.4 | 56.4 | 21.1 | 43.1 | 60.8 |
| AdvEnt | | 61.8 | 96.8 | 75.1 | 76.4 | 46.2 | 42.6 | 39.3 | 43.6 | 58.9 | 74.3 | 50.1 | 75.9 | 67.3 | 51.0 | 89.4 | 70.5 | 64.7 | 39.9 | 47.9 | 65.0 |
| FDA | | 61.9 | 96.9 | 77.2 | 75.3 | 46.5 | 42.0 | 39.8 | 47.1 | 61.0 | 72.7 | 54.6 | 63.8 | 68.4 | 50.1 | 90.1 | 72.8 | 68.0 | 35.5 | 50.8 | 64.2 |
| Cycada | | 63.3 | 96.8 | 75.5 | 79.1 | 38.0 | 40.3 | 42.1 | 48.2 | 61.2 | 76.9 | 52.1 | 77.6 | 68.6 | 51.7 | 90.4 | 71.7 | 70.4 | 43.3 | 52.6 | 65.7 |
| CycleGAN | | 66.0 | 97.1 | 77.6 | 84.3 | 42.7 | 46.3 | 42.8 | 47.5 | 61.0 | 84.0 | 55.2 | 83.4 | 69.4 | 51.8 | 90.7 | 73.7 | 76.2 | 54.2 | 50.7 | 65.6 |
| Oracle | | 66.9 | 97.4 | 78.6 | 88.1 | 50.7 | 50.5 | 46.2 | 51.3 | 64.4 | 88.1 | 55.3 | 87.4 | 70.9 | 52.7 | 91.6 | 72.4 | 73.2 | 31.8 | 52.2 | 67.4 |
## Visualization
If you want to visualize the segmentation results during training, you should set ``--debug``.
```
CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes --log logs/src_only/gtav2cityscapes --debug
```
Then you can find images, predictions and labels in directory ``logs/src_only/gtav2cityscapes/visualize/``.
Translation models such as CycleGAN save images by default. Here are the source-style images and their translated versions.
## TODO
Support methods: AdaptSeg
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{CycleGAN,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle={ICCV},
year={2017}
}
@inproceedings{cycada,
title={Cycada: Cycle-consistent adversarial domain adaptation},
author={Hoffman, Judy and Tzeng, Eric and Park, Taesung and Zhu, Jun-Yan and Isola, Phillip and Saenko, Kate and Efros, Alexei and Darrell, Trevor},
booktitle={ICML},
year={2018},
}
@inproceedings{Advent,
author = {Vu, Tuan-Hung and Jain, Himalaya and Bucher, Maxime and Cord, Matthieu and Perez, Patrick},
title = {ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation},
booktitle = {CVPR},
year = {2019}
}
@inproceedings{FDA,
author = {Yanchao Yang and
Stefano Soatto},
title = {{FDA:} Fourier Domain Adaptation for Semantic Segmentation},
booktitle = {CVPR},
year = {2020}
}
```
================================================
FILE: examples/domain_adaptation/semantic_segmentation/advent.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
from PIL import Image
import numpy as np
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
sys.path.append('../../..')
from tllib.alignment.advent import Discriminator, DomainAdversarialEntropyLoss
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
import tllib.vision.transforms.segmentation as T
from tllib.vision.transforms import DeNormalizeAndTranspose
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter, Meter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for ADVENT: build data pipelines, segmentation model and
    domain discriminator, then run adversarial training, or only evaluate the
    model when ``args.phase == 'test'``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    # target images use a fixed 2:1 crop ratio and no color jitter
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=(2., 2.), scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)
    # wrap loaders so `next()` never raises StopIteration during training
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)
    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)
    # define optimizer and lr scheduler
    # NOTE(review): the SGD is built with lr=args.lr AND the lambda multiplies
    # by args.lr again — presumably model.get_parameters() sets base lr to 1
    # for its param groups; confirm against the tllib segmentation models.
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(), lr=args.lr_d, betas=(0.9, 0.99))
    # polynomial decay over the full training schedule
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))
    lr_scheduler_d = LambdaLR(optimizer_d, lambda x: (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))
    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1
    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # sizes are given as (W, H) on the CLI; Upsample expects (H, W), hence the
    # reversal — assumed from the [::-1]; TODO confirm
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)
    # define visualization function
    decode = train_source_dataset.decode_target
    def visualize(image, pred, label, prefix):
        """Save the input image, prediction and ground-truth label as PNGs.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))
    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return
    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train, criterion, dann, optimizer,
              lr_scheduler, optimizer_d, lr_scheduler_d, epoch, visualize if args.debug else None, args)
        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()
        # calculate the mean iou over partial classes
        indexes = [train_source_dataset.classes.index(name) for name
                   in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()
        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, dann,
          optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: SGD, lr_scheduler_d: LambdaLR,
          epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of ADVENT adversarial training.

    Each iteration alternates two steps: (1) update the segmentation network
    with the supervised source loss plus an adversarial loss that makes target
    predictions look source-like to the discriminator; (2) update the
    discriminator on detached source/target predictions.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))
    # switch to train mode
    model.train()
    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        optimizer.zero_grad()
        optimizer_d.zero_grad()
        # Step 1: Train the segmentation network, freeze the discriminator
        # NOTE(review): freezing is presumably handled inside
        # DomainAdversarialEntropyLoss via its eval()/train() modes; confirm
        # that Step-1 gradients do not leak into the discriminator update.
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()
        # adversarial training to fool the discriminator
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()
        # Step 2: Train the discriminator
        # predictions are detached so no gradient reaches the segmentation net
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') + dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()
        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()
        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate the model on the validation set and return the confusion matrix."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc, iou],
        prefix='Test: ')
    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)
    with torch.no_grad():
        start = time.time()
        for step, (image, label) in enumerate(val_loader):
            image = image.to(device)
            label = label.long().to(device)
            # upsample logits to the label resolution before scoring
            prediction = interp(model(image))
            loss = criterion(prediction, label)
            # record loss and running accuracy/IoU
            losses.update(loss.item(), image.size(0))
            confmat.update(label.flatten(), prediction.argmax(1).flatten())
            acc_global, per_class_acc, per_class_iou = confmat.compute()
            acc.update(per_class_acc.mean().item())
            iou.update(per_class_iou.mean().item())
            # measure elapsed time
            batch_time.update(time.time() - start)
            start = time.time()
            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(image[0], prediction[0], label[0], "val_{}".format(step))
    return confmat
if __name__ == '__main__':
    # enumerate model factories exposed by tllib for the --arch choices
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # NOTE(review): dataset_names is computed but not wired into the
    # -s/-t arguments as `choices`; presumably intentional — confirm.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='ADVENT for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    # sizes are (width, height) pairs
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),
                        help='the output image size during test')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off', type=float, default=0.001,
                        help='trade-off parameter for the advent loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=2, type=int,
                        metavar='N',
                        help='mini-batch size (default: 2)')
    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.")
    parser.add_argument("--weight-decay", type=float, default=0.0005, help="Regularisation parameter for L2-loss.")
    parser.add_argument("--lr-power", type=float, default=0.9,
                        help="Decay parameter to compute the learning rate (only for deeplab).")
    parser.add_argument("--lr-d", default=1e-4, type=float,
                        metavar='LR', help='initial learning rate for discriminator')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='advent',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions during training')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/semantic_segmentation/advent.sh
================================================
# ADVENT (adversarial entropy minimization) for semantic segmentation
# domain adaptation, one command per benchmark transfer task.
# GTA5 to Cityscapes
CUDA_VISIBLE_DEVICES=0 python advent.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
--log logs/advent/gtav2cityscapes
# Synthia to Cityscapes
CUDA_VISIBLE_DEVICES=0 python advent.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \
--log logs/advent/synthia2cityscapes
# Cityscapes to Foggy
CUDA_VISIBLE_DEVICES=0 python advent.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
--log logs/advent/cityscapes2foggy
================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycada.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage, Compose, Lambda
sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
from tllib.translation.cycada import SemanticConsistency
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
from tllib.vision.transforms import Denormalize, NormalizeAndTranspose
import tllib.vision.transforms.segmentation as T
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train CyCADA image translation (CycleGAN + semantic consistency)
    between a source and a target segmentation dataset.

    In phase 'test', only translate the source dataset with a resumed
    S->T generator and exit. After training, if ``--translated-root`` is
    given, the translated source dataset is written there.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # Images are normalized to [-1, 1], the range the CycleGAN networks expect.
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    # Endless iterators: an "epoch" is defined by args.iters_per_epoch, not
    # by dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    # Constant lr for the first args.epochs epochs, then linear decay towards
    # zero over the following args.epochs_decay epochs.
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        # Test phase: translate the source dataset with the S->T generator.
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_semantic = SemanticConsistency(ignore_index=[args.ignore_label]+train_source_dataset.ignore_classes).to(device)
    # Upsample segmentation logits back to training resolution
    # (train_size is (W, H); Upsample expects (H, W), hence the reversal).
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)

    # define segmentation model and predict function
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)
    if args.pretrain:
        print("Loading pretrain segmentation model from", args.pretrain)
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # The segmentation model stays in eval mode: it only supplies the
    # semantic-consistency signal for the generators.
    model.eval()

    # Convert a CycleGAN-normalized tensor ([-1, 1]) into the input format
    # the segmentation model was trained on.
    cycle_gan_tensor_to_segmentation_tensor = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        Lambda(lambda image: image.mul(255).permute((1, 2, 0))),
        NormalizeAndTranspose(),
    ])

    def predict(image):
        # Segment a single CycleGAN-space image; return logits upsampled to
        # the training resolution.
        image = cycle_gan_tensor_to_segmentation_tensor(image.squeeze())
        image = image.unsqueeze(dim=0).to(device)
        prediction = model(image)
        return interp_train(prediction)

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])
    decode = train_source_dataset.decode_target

    def visualize(image, name, pred=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
            pred (tensor): predictions in shape C x H x W
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))
        if pred is not None:
            # argmax over classes, then map class ids to a color image
            pred = pred.detach().max(dim=0).indices.cpu().numpy()
            pred = decode(pred)
            pred.save(logger.get_image_path("pred_{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,
              criterion_gan, criterion_cycle, criterion_identity, criterion_semantic, optimizer_G, optimizer_D,
              fake_S_pool, fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )

    if args.translated_root is not None:
        # After training, translate the whole source dataset to target style.
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,
          criterion_gan, criterion_cycle, criterion_identity, criterion_semantic,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool,
          epoch: int, visualize, args: argparse.Namespace):
    """Train CyCADA for one epoch.

    Each iteration first updates the generators (adversarial + cycle +
    identity + semantic-consistency losses, with discriminators frozen),
    then updates the discriminators on pooled fake images.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_semantic_S2T = AverageMeter('sem_S2T', ':3.2f')
    losses_semantic_T2S = AverageMeter('sem_T2S', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T,
         losses_semantic_S2T, losses_semantic_T2S],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        real_S, label_s = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)
        label_s = label_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(B))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        # Semantic loss
        # S->T direction is supervised by the source ground-truth labels ...
        pred_fake_T = predict(fake_T)
        pred_real_S = predict(real_S)
        loss_semantic_S2T = criterion_semantic(pred_fake_T, label_s) * args.trade_off_semantic
        # ... while T->S uses the model's own prediction on real_T as a
        # pseudo-label, since target ground truth is unavailable.
        pred_fake_S = predict(fake_S)
        pred_real_T = predict(real_T)
        loss_semantic_T2S = criterion_semantic(pred_fake_S, pred_real_T.max(1).indices) * args.trade_off_semantic

        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + \
                 loss_identity_S + loss_identity_T + loss_semantic_S2T + loss_semantic_T2S
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S
        # fakes are drawn from the history pool to stabilize training
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # measure elapsed time
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        losses_semantic_S2T.update(loss_semantic_S2T.item(), real_S.size(0))
        losses_semantic_T2S.update(loss_semantic_T2S.item(), real_S.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            # dump sample translations plus their segmentation predictions
            for image, prediction, name in zip([real_S, real_T, fake_S, fake_T],
                                               [pred_real_S, pred_real_T, pred_fake_S, pred_fake_T],
                                               ["real_S", "real_T", "fake_S", "fake_T"]):
                visualize(image[0], "{}_{}".format(i, name), prediction[0])
            for image, name in zip([rec_S, rec_T, identity_S, identity_T],
                                   ["rec_S", "rec_T", "identity_S", "identity_T"]):
                visualize(image[0], "{}_{}".format(i, name))
if __name__ == '__main__':
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='Cycada for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    # FIX: validate -s/-t against the available datasets; dataset_names was
    # previously computed but never used.
    parser.add_argument('-s', '--source', choices=dataset_names, help='source domain(s)')
    parser.add_argument('-t', '--target', choices=dataset_names, help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 256),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    # FIX: typo "segementation" -> "segmentation"
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoints for segmentation model')
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore cyclegan model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')
    parser.add_argument('--trade-off-semantic', type=float, default=1.0, help='trade off for semantic loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int,
                        metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    # FIX: help text said "(default: 100)" but the actual default is 400.
    parser.add_argument('-p', '--print-freq', default=400, type=int,
                        metavar='N', help='print frequency (default: 400)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='cycada',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycada.sh
================================================
# CyCADA pipeline: (1) train the semantically-consistent CycleGAN and dump
# the translated source dataset, then (2) train a source-only segmentation
# model on the translated images.

# GTA5 to Cityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycada.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
  --log logs/cycada/gtav2cityscapes --pretrain logs/src_only/gtav2cityscapes/checkpoints/59.pth \
  --translated-root data/GTA52Cityscapes/cycada_39

# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/cycada_39 data/Cityscapes \
  -s GTA5 -t Cityscapes --log logs/cycada_src_only/gtav2cityscapes

## Synthia to Cityscapes
# First, train the Cycada
CUDA_VISIBLE_DEVICES=0 python cycada.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \
  --log logs/cycada/synthia2cityscapes --pretrain logs/src_only/synthia2cityscapes/checkpoints/59.pth \
  --translated-root data/Synthia2Cityscapes/cycada_39

# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Synthia2Cityscapes/cycada_39 data/Cityscapes \
  -s Synthia -t Cityscapes --log logs/cycada_src_only/synthia2cityscapes

# Cityscapes to FoggyCityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycada.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
  --log logs/cycada/cityscapes2foggy --pretrain logs/src_only/cityscapes2foggy/checkpoints/59.pth \
  --translated-root data/Cityscapes2Foggy/cycada_39

# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/cycada_39 data/Cityscapes \
  -s Cityscapes -t FoggyCityscapes --log logs/cycada_src_only/cityscapes2foggy
================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycle_gan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage, Compose
sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
import tllib.vision.datasets.segmentation as datasets
from tllib.vision.transforms import Denormalize
import tllib.vision.transforms.segmentation as T
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train a CycleGAN between a source and a target segmentation dataset.

    In phase 'test', only translate the source dataset with a resumed
    S->T generator and exit. After training, if ``--translated-root`` is
    given, the translated source dataset is written there.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic CUDNN trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # Images are normalized to [-1, 1], the range the CycleGAN networks expect.
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    # Endless iterators: an "epoch" is defined by args.iters_per_epoch, not
    # by dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999))
    # Constant lr for the first args.epochs epochs, then linear decay towards
    # zero over the following args.epochs_decay epochs.
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        # Test phase: translate the source dataset with the S->T generator.
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs+args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
              criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
              fake_S_pool, fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )

    if args.translated_root is not None:
        # After training, translate the whole source dataset to target style.
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
          criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of CycleGAN training.

    Each iteration performs a generator update (adversarial + cycle +
    identity losses, with both discriminators frozen) followed by a
    discriminator update fed with pooled previously-generated fakes.
    """
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_g_s2t = AverageMeter('G_S2T', ':3.2f')
    meter_g_t2s = AverageMeter('G_T2S', ':3.2f')
    meter_d_s = AverageMeter('D_S', ':3.2f')
    meter_d_t = AverageMeter('D_T', ':3.2f')
    meter_cycle_s = AverageMeter('cycle_S', ':3.2f')
    meter_cycle_t = AverageMeter('cycle_T', ':3.2f')
    meter_idt_s = AverageMeter('idt_S', ':3.2f')
    meter_idt_t = AverageMeter('idt_T', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_g_s2t, meter_g_t2s, meter_d_s, meter_d_t,
         meter_cycle_s, meter_cycle_t, meter_idt_s, meter_idt_t],
        prefix="Epoch: [{}]".format(epoch))

    tick = time.time()
    for step in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)
        # time spent waiting for data
        meter_data.update(time.time() - tick)

        # forward passes: translations and their cycle reconstructions
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ---- generator update (discriminators frozen) ----
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # adversarial terms: fool D_T with fake_T, and D_S with fake_S
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # cycle-consistency terms: reconstructions should match the originals
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # identity terms: each generator should leave its own domain unchanged
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        loss_G = (loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T
                  + loss_identity_S + loss_identity_T)
        loss_G.backward()
        optimizer_G.step()

        # ---- discriminator update ----
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # fakes are drawn from the history pools to stabilize training
        pooled_fake_S = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(pooled_fake_S), False))
        loss_D_S.backward()
        pooled_fake_T = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(pooled_fake_T), False))
        loss_D_T.backward()
        optimizer_D.step()

        # ---- bookkeeping ----
        n = real_S.size(0)
        for meter, value in ((meter_g_s2t, loss_G_S2T), (meter_g_t2s, loss_G_T2S),
                             (meter_d_s, loss_D_S), (meter_d_t, loss_D_T),
                             (meter_cycle_s, loss_cycle_S), (meter_cycle_t, loss_cycle_T),
                             (meter_idt_s, loss_identity_S), (meter_idt_t, loss_identity_T)):
            meter.update(value.item(), n)
        meter_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            # dump one sample of every intermediate image
            samples = [real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T]
            labels = ["real_S", "real_T", "fake_S", "fake_T", "rec_S",
                      "rec_T", "identity_S", "identity_T"]
            for sample, label in zip(samples, labels):
                visualize(sample[0], "{}_{}".format(step, label))
if __name__ == '__main__':
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='CycleGAN for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    # FIX: validate -s/-t against the available datasets; dataset_names was
    # previously computed but never used.
    parser.add_argument('-s', '--source', choices=dataset_names, help='source domain(s)')
    parser.add_argument('-t', '--target', choices=dataset_names, help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int,
                        metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    # FIX: help text said "(default: 100)" but the actual default is 400.
    parser.add_argument('-p', '--print-freq', default=400, type=int,
                        metavar='N', help='print frequency (default: 400)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='cyclegan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycle_gan.sh
================================================
# CycleGAN pipeline: (1) train the CycleGAN and dump the translated source
# dataset, then (2) train a source-only segmentation model on the
# translated images.

# GTA5 to Cityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
  --log logs/cyclegan/gtav2cityscapes --translated-root data/GTA52Cityscapes/CycleGAN_39

# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/CycleGAN_39 data/Cityscapes \
  -s GTA5 -t Cityscapes --log logs/cyclegan_src_only/gtav2cityscapes

# Cityscapes to FoggyCityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
  --log logs/cyclegan/cityscapes2foggy --translated-root data/Cityscapes2Foggy/CycleGAN_39

# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/CycleGAN_39 data/Cityscapes \
  -s Cityscapes -t FoggyCityscapes --log logs/cyclegan_src_only/cityscapes2foggy
================================================
FILE: examples/domain_adaptation/semantic_segmentation/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
from PIL import Image
import numpy as np
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
sys.path.append('../../..')
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
import tllib.vision.transforms.segmentation as T
from tllib.vision.transforms import DeNormalizeAndTranspose
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter, Meter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Source-only (ERM) training for segmentation domain adaptation.

    Trains a segmentation model with pixel-wise cross entropy on the source
    domain only, evaluates mean IoU on the target validation split after every
    epoch, checkpoints each epoch, and keeps a copy of the best checkpoint
    (by mean IoU over ``evaluate_classes``).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding makes cuDNN deterministic, which can slow training down.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    # Validation resizes the input and evaluates at the (larger) output size.
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this lambda,
    # so the effective lr is args.lr ** 2 * poly-decay. The same pattern appears
    # in the sibling scripts, so it looks intentional — confirm before changing.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch)
                            ** (args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    # Sizes are given as (width, height) on the CLI; nn.Upsample expects (H, W),
    # hence the [::-1] reversal.
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the de-normalized image, the color-coded label and the
        color-coded prediction as PNGs under the logger's image directory.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, model, interp_train, criterion, optimizer,
              lr_scheduler, epoch, visualize if args.debug else None, args)
        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize if args.debug else None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()
        # calculate the mean iou over partial classes (only evaluate_classes count)
        indexes = [train_source_dataset.classes.index(name) for name
                   in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()
        # remember best iou and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))
    logger.close()
def train(train_source_iter: ForeverDataIterator, model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of source-only training.

    Args:
        train_source_iter: endless iterator over the source training loader
        model: segmentation network (must expose ``num_classes``)
        interp: upsampling module mapping logits to label resolution
        criterion: pixel-wise classification loss
        optimizer: SGD optimizer, stepped every iteration
        lr_scheduler: learning-rate scheduler, stepped every iteration
        epoch: current epoch index (used only for logging)
        visualize: optional callable that saves debug images, or ``None``
        args: parsed command-line arguments
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, accuracies_s, iou_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_s, label_s = x_s.to(device), label_s.long().to(device)
        data_time.update(time.time() - tic)

        # forward, upsample logits to label resolution, then pixel-wise CE
        pred_s = interp(model(x_s))
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        optimizer.step()
        lr_scheduler.step()

        # bookkeeping: loss meter plus running confusion matrix for acc / IoU
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        accuracies_s.update(acc_s.mean().item())
        iou_s.update(iu_s.mean().item())

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(step))
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate the model on ``val_loader`` and return the accumulated
    ConfusionMatrix (callers compute mIoU from it)."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc, iou],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tic = time.time()
        for step, (x, label) in enumerate(val_loader):
            x, label = x.to(device), label.long().to(device)

            # forward pass, upsampled to the evaluation resolution
            output = interp(model(x))
            loss = criterion(output, label)

            # record loss and per-pixel statistics
            losses.update(loss.item(), x.size(0))
            confmat.update(label.flatten(), output.argmax(1).flatten())
            acc_global, accs, iu = confmat.compute()
            acc.update(accs.mean().item())
            iou.update(iu.mean().item())

            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(x[0], output[0], label[0], "val_{}".format(step))
    return confmat
if __name__ == '__main__':
    # Discover available segmentation architectures by name (lowercase callables
    # exported by the models package).
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # NOTE(review): dataset_names is computed but never used below (-s/-t have
    # no `choices=`); presumably kept for help text or future validation.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='Source Only for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.5, 1.),
                        help='the resize scale for the random resize crop')
    # All sizes below are given as (width, height) pairs.
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),
                        help='the output image size during test')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=2, type=int,
                        metavar='N',
                        help='mini-batch size (default: 2)')
    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.")
    parser.add_argument("--weight-decay", type=float, default=0.0005,
                        help="Regularisation parameter for L2-loss.")
    parser.add_argument("--lr-power", type=float, default=0.9,
                        help="Decay parameter to compute the learning rate (only for deeplab).")
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions during training')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/semantic_segmentation/erm.sh
================================================
# Source-only (ERM) baselines for semantic segmentation domain adaptation,
# plus in-domain "oracle" upper bounds (train and test on the same domain).

# Source Only
# GTA5 to Cityscapes
CUDA_VISIBLE_DEVICES=0 python erm.py data/GTA5 data/Cityscapes \
-s GTA5 -t Cityscapes --log logs/erm/gtav2cityscapes
# Synthia to Cityscapes
CUDA_VISIBLE_DEVICES=0 python erm.py data/synthia data/Cityscapes \
-s Synthia -t Cityscapes --log logs/erm/synthia2cityscapes
# Cityscapes to FoggyCityscapes
CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \
-s Cityscapes -t FoggyCityscapes --log logs/erm/cityscapes2foggy

# Oracle
# Oracle Results on Cityscapes
CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \
-s Cityscapes -t Cityscapes --log logs/oracle/cityscapes
# Oracle Results on Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python erm.py data/Cityscapes data/Cityscapes \
-s FoggyCityscapes -t FoggyCityscapes --log logs/oracle/foggy_cityscapes
================================================
FILE: examples/domain_adaptation/semantic_segmentation/fda.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
from PIL import Image
import numpy as np
import os
import math
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
sys.path.append('../../..')
from tllib.translation.fourier_transform import FourierTransform
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
import tllib.vision.transforms.segmentation as T
from tllib.vision.transforms import DeNormalizeAndTranspose
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter, Meter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def robust_entropy(y, ita=1.5, num_classes=19, reduction='mean'):
    """Robust entropy proposed in `FDA: Fourier Domain Adaptation for Semantic
    Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_

    The per-pixel Shannon entropy of the predicted class distribution is
    normalized to :math:`[0, 1]` by :math:`\\log(\\text{num\\_classes})` and then
    raised to a power, so that high-entropy (uncertain) pixels dominate the loss.

    Args:
        y (tensor): logits output of segmentation model in shape of :math:`(N, C, H, W)`
        ita (float, optional): exponent of the robust penalty; larger values
            focus the loss more on uncertain pixels. Default: 1.5
        num_classes (int, optional): number of classes. Default: 19
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output. Default: ``'mean'``

    Returns:
        Scalar by default. If :attr:`reduction` is ``'none'``, then
        :math:`(N, H, W)` (only the class dimension is reduced).
    """
    # per-pixel entropy, normalized to [0, 1]
    P = F.softmax(y, dim=1)
    logP = F.log_softmax(y, dim=1)
    PlogP = P * logP
    ent = -1.0 * PlogP.sum(dim=1)
    ent = ent / math.log(num_classes)
    # compute robust entropy; 1e-8 keeps the fractional power stable at zero
    ent = ent ** 2.0 + 1e-8
    ent = ent ** ita
    if reduction == 'mean':
        return ent.mean()
    else:
        return ent
def main(args: argparse.Namespace):
    """FDA training for segmentation domain adaptation.

    Source images are translated to the target style in frequency space
    (swapping low-frequency amplitudes) before supervised training; an optional
    robust-entropy term on target predictions is weighted by
    ``args.entropy_weight``. Evaluates target mIoU every epoch and keeps the
    best checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding makes cuDNN deterministic, which can slow training down.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            T.Resize(image_size=args.train_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ])
    )
    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)

    # collect the absolute paths of all images in the target dataset
    target_image_list = train_target_dataset.collect_image_paths()
    # build a fourier transform that translate source images to the target style;
    # target amplitude spectra are cached under <log>/amplitudes (rebuild=False reuses them)
    fourier_transform = T.wrapper(FourierTransform)(target_image_list, os.path.join(logger.root, "amplitudes"),
                                                    rebuild=False, beta=args.beta)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.Resize((2048, 1024)),  # convert source image to the size of the target image before fourier transform
            fourier_transform,
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this lambda,
    # so the effective lr is args.lr ** 2 * poly-decay. Same pattern as sibling
    # scripts — looks intentional, confirm before changing.
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    # Sizes are (width, height) on the CLI; nn.Upsample expects (H, W), hence [::-1].
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the de-normalized image, the color-coded label and the
        color-coded prediction as PNGs under the logger's image directory.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train, criterion, optimizer,
              lr_scheduler, epoch, visualize if args.debug else None, args)
        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()
        # calculate the mean iou over partial classes (only evaluate_classes count)
        indexes = [train_source_dataset.classes.index(name) for name
                   in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()
        # remember best mean iou and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))
    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model, interp, criterion, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Run one FDA training epoch.

    Back-propagates the supervised cross entropy on (translated) source images
    plus ``args.entropy_weight`` times the robust entropy of the target logits.
    The cross entropy against target labels is computed for monitoring only and
    is never back-propagated.

    Args:
        train_source_iter: endless iterator over the source training loader
        train_target_iter: endless iterator over the target training loader
        model: segmentation network (must expose ``num_classes``)
        interp: upsampling module mapping logits to label resolution
        criterion: pixel-wise classification loss
        optimizer: SGD optimizer, stepped every iteration
        lr_scheduler: learning-rate scheduler, stepped every iteration
        epoch: current epoch index (used only for logging)
        visualize: optional callable that saves debug images, or ``None``
        args: parsed command-line arguments
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s, label_s = x_s.to(device), label_s.long().to(device)
        x_t, label_t = x_t.to(device), label_t.long().to(device)
        data_time.update(time.time() - tic)

        # supervised branch: translated source images with ground truth
        pred_s = interp(model(x_s))
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # target branch: robust entropy is computed on the raw (pre-interp)
        # logits; target CE is recorded only, its gradient is never used
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_cls_t = criterion(pred_t, label_t)
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()

        optimizer.step()
        lr_scheduler.step()

        # bookkeeping: losses plus running confusion matrices for acc / IoU
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_s.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(step))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(step))
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate the model on ``val_loader`` and return the accumulated
    ConfusionMatrix (callers compute mIoU from it)."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc, iou],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tic = time.time()
        for step, (x, label) in enumerate(val_loader):
            x, label = x.to(device), label.long().to(device)

            # forward pass, upsampled to the evaluation resolution
            output = interp(model(x))
            loss = criterion(output, label)

            # record loss and per-pixel statistics
            losses.update(loss.item(), x.size(0))
            confmat.update(label.flatten(), output.argmax(1).flatten())
            acc_global, accs, iu = confmat.compute()
            acc.update(accs.mean().item())
            iou.update(iu.mean().item())

            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(x[0], output[0], label[0], "val_{}".format(step))
    return confmat
if __name__ == '__main__':
    # Discover available segmentation architectures by name (lowercase callables
    # exported by the models package).
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # NOTE(review): dataset_names is computed but never used below (-s/-t have
    # no `choices=`); presumably kept for help text or future validation.
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='FDA for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    # All sizes below are given as (width, height) pairs.
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),
                        help='the output image size during test')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    # NOTE(review): the default entropy weight is 0, i.e. the robust-entropy
    # term is disabled unless set explicitly on the command line.
    parser.add_argument("--entropy-weight", type=float, default=0., help="weight for entropy")
    parser.add_argument("--ita", type=float, default=2.0, help="ita for robust entropy")
    parser.add_argument("--beta", type=int, default=1, help="beta for FDA")
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=2, type=int,
                        metavar='N',
                        help='mini-batch size (default: 2)')
    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.")
    parser.add_argument("--weight-decay", type=float, default=0.0005, help="Regularisation parameter for L2-loss.")
    parser.add_argument("--lr-power", type=float, default=0.9,
                        help="Decay parameter to compute the learning rate (only for deeplab).")
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='fda',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions during training')
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/semantic_segmentation/fda.sh
================================================
# FDA (Fourier Domain Adaptation) training runs. --debug saves intermediate
# images and predictions so the frequency-space translation can be inspected.

# GTA5 to Cityscapes
CUDA_VISIBLE_DEVICES=0 python fda.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
--log logs/fda/gtav2cityscapes --debug
# Synthia to Cityscapes
CUDA_VISIBLE_DEVICES=0 python fda.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \
--log logs/fda/synthia2cityscapes --debug
# Cityscapes to FoggyCityscapes
CUDA_VISIBLE_DEVICES=0 python fda.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
--log logs/fda/cityscapes2foggy --debug
================================================
FILE: examples/domain_adaptation/wilds_image_classification/README.md
================================================
# Unsupervised Domain Adaptation for WILDS (Image Classification)
## Installation
It’s suggested to use **pytorch==1.9.0** in order to reproduce the benchmark results.
You need to install **apex** following ``https://github.com/NVIDIA/apex``. Then run
```
pip install -r requirements.txt
```
## Dataset
Following datasets can be downloaded automatically:
- [DomainNet](http://ai.bu.edu/M3SDA/)
- [iwildcam (WILDS)](https://wilds.stanford.edu/datasets/)
- [camelyon17 (WILDS)](https://wilds.stanford.edu/datasets/)
- [fmow (WILDS)](https://wilds.stanford.edu/datasets/)
## Supported Methods
Supported methods include:
- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)
- [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791)
- [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636)
- [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667)
- [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)
- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch)](https://arxiv.org/abs/2001.07685)
## Usage
Our code is based
on [https://github.com/NVIDIA/apex/tree/master/examples/imagenet](https://github.com/NVIDIA/apex/tree/master/examples/imagenet).
It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and
VGG, on the WILDS dataset.
Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed
precision "optimization levels" or `opt_level`s.
For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).
The shell files give all the training scripts we use, e.g.,
```
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
--lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121
```
## Results
### Performance on WILDS-FMoW (DenseNet-121)
| Methods | Val Avg Acc | Test Avg Acc | Val Worst-region Acc | Test Worst-region Acc |
|---------|-------------|--------------|----------------------|-----------------------|
| ERM | 59.8 | 53.3 | 50.2 | 32.2 |
| DANN | 60.6 | 54.2 | 49.1 | 34.8 |
| DAN | 61.7 | 55.5 | 48.3 | 35.3 |
| JAN | 61.5 | 55.3 | 50.6 | 36.3 |
| CDAN | 60.7 | 55.0 | 47.4 | 35.5 |
| MDD | 60.1 | 55.1 | 49.3 | 35.9 |
| FixMatch| 61.1 | 55.1 | 51.8 | 37.4 |
### Performance on WILDS-IWildCAM (ResNet50)
| Methods | Val Avg Acc | Test Avg Acc | Val F1 macro | Test F1 macro |
|---------|-------------|--------------|--------------|---------------|
| ERM | 59.9 | 72.6 | 36.3 | 32.9 |
| DANN | 57.4 | 70.1 | 35.8 | 32.2 |
| DAN | 63.7 | 69.4 | 39.1 | 31.6 |
| JAN | 62.4 | 68.7 | 37.6 | 31.5 |
| CDAN | 57.6 | 71.2 | 37.0 | 30.6 |
| MDD | 58.3 | 73.5 | 35.0 | 30.0 |
### Visualization
We use tensorboard to record the training process and visualize the outputs of the models.
```
tensorboard --logdir=logs
```
### Distributed training
We use `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training, with one GPU per process.
```
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 6666 erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
--lr 0.1 --opt-level O1 --deterministic --vflip 0.5 -j 8 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121_bs_128
```
## TODO
1. update experiment results
2. support DomainNet
3. support camelyon17
4. support self-training methods
5. support self-supervised methods
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{DANN,
author = {Ganin, Yaroslav and Lempitsky, Victor},
Booktitle = {ICML},
Title = {Unsupervised domain adaptation by backpropagation},
Year = {2015}
}
@inproceedings{DAN,
author = {Mingsheng Long and
Yue Cao and
Jianmin Wang and
Michael I. Jordan},
title = {Learning Transferable Features with Deep Adaptation Networks},
booktitle = {ICML},
year = {2015},
}
@inproceedings{JAN,
title={Deep transfer learning with joint adaptation networks},
author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I},
booktitle={ICML},
year={2017},
}
@inproceedings{CDAN,
author = {Mingsheng Long and
Zhangjie Cao and
Jianmin Wang and
Michael I. Jordan},
title = {Conditional Adversarial Domain Adaptation},
booktitle = {NeurIPS},
year = {2018}
}
@inproceedings{MDD,
title={Bridging theory and algorithm for domain adaptation},
author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael},
booktitle={ICML},
year={2019},
}
@inproceedings{FixMatch,
title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},
author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},
  booktitle={NeurIPS},
year={2020}
}
```
================================================
FILE: examples/domain_adaptation/wilds_image_classification/cdan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier as Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Entry point: train/evaluate CDAN on a WILDS dataset with apex mixed precision.

    Builds transforms and datasets, a classifier plus a conditional domain
    discriminator, wraps them with apex amp (and optionally apex
    DistributedDataParallel), then either evaluates the best checkpoint
    (phase == 'test') or runs the epoch loop with per-epoch validation and
    best-checkpoint saving.
    """
    # TensorBoard writer and CompleteLogger exist only on rank 0; other ranks
    # keep writer=None and skip all logging.
    writer = None
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade autotuning speed for reproducibility and seed PyTorch RNG.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # torch.distributed.launch / torchrun export WORLD_SIZE; >1 means distributed mode.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # One process per GPU; local_rank selects this process's device.
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    features_dim = model.features_dim
    # Discriminator input is either a randomized multilinear map of size
    # randomized_dim, or the full feature x prediction outer product.
    if args.randomized:
        domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024, sigmoid=False)
    else:
        domain_discri = DomainDiscriminator(features_dim * args.num_classes, hidden_size=1024, sigmoid=False)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    domain_discri = domain_discri.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer,
                                                       opt_level=args.opt_level,
                                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                       loss_scale=args.loss_scale
                                                       )
    # Use cosine annealing learning rate strategy
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # define loss function
    domain_adv = ConditionalDomainAdversarialLoss(
        domain_discri, num_classes=args.num_classes, features_dim=features_dim, randomized=args.randomized,
        randomized_dim=args.randomized_dim, sigmoid=False
    )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
        domain_adv = DDP(domain_adv, delay_allreduce=True)
    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)
    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only bound on rank 0 — test phase presumably
        # runs single-process; confirm before testing under torchrun.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    for epoch in range(args.epochs):
        if args.distributed:
            # Re-seed the samplers so each epoch uses a different shuffle.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer,
              args)
        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        # (prec1 from the LAST entry in test_list decides "best")
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv,
          optimizer, epoch, writer, args):
    """Train CDAN for one epoch.

    Each step combines a supervised loss on labeled source batches with the
    conditional domain-adversarial transfer loss between source and target
    features; console/TensorBoard logging happens every ``args.print_freq``
    iterations and only on rank 0.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    # One epoch is bounded by the shorter loader; the unlabeled loader is
    # cycled so it can never exhaust before the labeled one.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        # Source and target are forwarded in one concatenated batch (shares
        # BatchNorm statistics), then split back by original batch sizes.
        n_s, n_t = len(input_s), len(input_t)
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = domain_adv(output_s, feature_s, output_t, feature_t)
        loss = loss_s + loss_trans * args.trade_off
        # compute gradient and do SGD step
        optimizer.zero_grad()
        # amp scales the loss so fp16 gradients don't underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
                # domain_adv is DDP-wrapped in distributed mode, hence .module.
                domain_acc = domain_adv.module.domain_discriminator_accuracy
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data
                domain_acc = domain_adv.domain_discriminator_accuracy
            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            domain_accs.update(to_python_float(domain_acc), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i
            torch.cuda.synchronize()
            # Average per-iteration time over the print_freq window.
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans,
                          domain_acc=domain_accs, top1=top1))
if __name__ == '__main__':
    # Every lowercase, callable, non-dunder name in torchvision.models is a
    # valid backbone constructor for --arch.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='CDAN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('-r', '--randomized', action='store_true',
                        help='using randomized multi-linear-map (default: False)')
    parser.add_argument('-rd', '--randomized-dim', default=1024, type=int,
                        help='randomized dimension when using randomized multi-linear-map (default: 1024)')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # Fix: `type=bool` is an argparse footgun — bool("False") is True, so any
    # explicit value (even "False") would enable channels_last. Parse the
    # string explicitly; the option is still value-taking as before.
    parser.add_argument('--channels-last',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='use channels_last memory format (True/False, default: False)')
    parser.add_argument("--log", type=str, default='cdan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/cdan.sh
================================================
# CDAN on WILDS FMoW: DenseNet-121 backbone, AutoAugment "v0", vertical flips enabled.
CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
--lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/cdan/fmow/lr_0_1_aa_v0_densenet121
# CDAN on WILDS iWildCam: 448x448 inputs (no random crop scaling), extra unlabeled split,
# smaller batches, macro-F1 as the model-selection metric.
CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \
--deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
--log logs/cdan/iwildcam/lr_1_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/dan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier
from tllib.modules.kernels import GaussianKernel
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Entry point: train/evaluate DAN (multi-kernel MMD) on a WILDS dataset.

    Builds transforms and datasets, the classifier backbone, wraps it with apex
    amp mixed precision (and optionally apex DistributedDataParallel), then
    either evaluates the best checkpoint (phase == 'test') or runs the epoch
    loop with per-epoch validation and best-checkpoint saving.
    """
    # TensorBoard writer and CompleteLogger exist only on rank 0; other ranks
    # keep writer=None and must skip all logging.
    writer = None
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade autotuning speed for reproducibility and seed PyTorch RNG.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # torch.distributed.launch / torchrun export WORLD_SIZE; >1 means distributed mode.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # One process per GPU; local_rank selects this process's device.
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # Use cosine annealing learning rate strategy
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)
    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only bound on rank 0 — test phase presumably
        # runs single-process; confirm before testing under torchrun.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    # define loss function
    mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy(
        kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
        linear=not args.non_linear
    )
    for epoch in range(args.epochs):
        if args.distributed:
            # Re-seed the samplers so each epoch uses a different shuffle.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        # Bug fix: writer is None on non-zero ranks (and outside phase=='train'),
        # so the lr print and TensorBoard logging must be rank-0 only — previously
        # these two lines were unguarded, crashing distributed runs (cdan.py
        # already guards them).
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer,
              args)
        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        # (prec1 from the LAST entry in test_list decides "best")
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer, args):
    """Train DAN for one epoch.

    Each step combines a supervised loss on labeled source batches with a
    multi-kernel MMD transfer loss between source and target features;
    console/TensorBoard logging happens every ``args.print_freq`` iterations
    and only on rank 0.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    # One epoch is bounded by the shorter loader; the unlabeled loader is
    # cycled so it can never exhaust before the labeled one.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        # Source and target are forwarded in one concatenated batch (shares
        # BatchNorm statistics), then split back by original batch sizes.
        n_s, n_t = len(input_s), len(input_t)
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = mkmmd_loss(feature_s, feature_t)
        loss = loss_s + loss_trans * args.trade_off
        # compute gradient and do SGD step
        optimizer.zero_grad()
        # amp scales the loss so fp16 gradients don't underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data
            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i
            torch.cuda.synchronize()
            # Average per-iteration time over the print_freq window.
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time,
                          loss_s=losses_s, loss_trans=losses_trans, top1=top1))
if __name__ == '__main__':
    # Every lowercase, callable, non-dunder name in torchvision.models is a
    # valid backbone constructor for --arch.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='DAN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--non-linear', default=False, action='store_true',
                        help='whether not use the linear version')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # Fix: `type=bool` is an argparse footgun — bool("False") is True, so any
    # explicit value (even "False") would enable channels_last. Parse the
    # string explicitly; the option is still value-taking as before.
    parser.add_argument('--channels-last',
                        type=lambda s: str(s).lower() in ('true', '1', 'yes'),
                        default=False,
                        help='use channels_last memory format (True/False, default: False)')
    parser.add_argument("--log", type=str, default='dan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/dan.sh
================================================
# DAN on WILDS FMoW: DenseNet-121 backbone, AutoAugment "v0", vertical flips enabled.
CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
--lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dan/fmow/lr_0_1_aa_v0_densenet121
# DAN on WILDS iWildCam: 448x448 inputs (no random crop scaling), extra unlabeled split,
# smaller batches, macro-F1 as the model-selection metric.
CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \
--deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
--log logs/dan/iwildcam/lr_0_3_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier as Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Train/evaluate DANN on a WILDS dataset with apex mixed precision.

    Builds labeled (source) and unlabeled (target) data pipelines, a
    classifier plus a domain discriminator, initializes apex AMP (and apex
    DistributedDataParallel when launched with WORLD_SIZE > 1), then either
    runs the training loop with per-epoch evaluation, or — in 'test' phase —
    only evaluates the best checkpoint.
    """
    writer = None
    # Checkpoints, tensorboard and verbose prints happen on rank 0 only.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    # Reproducibility mode: disable the cudnn autotuner and seed torch.
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # WORLD_SIZE is exported by the distributed launcher; > 1 means multi-process.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # One process per GPU; local rank selects the device.
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    features_dim = model.features_dim
    # Discriminator emits raw logits (sigmoid=False); DomainAdversarialLoss
    # below is constructed with the matching flag.
    domain_discri = DomainDiscriminator(features_dim, hidden_size=1024, sigmoid=False)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    domain_discri = domain_discri.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer,
                                                       opt_level=args.opt_level,
                                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                       loss_scale=args.loss_scale
                                                       )
    # Use cosine annealing learning rate strategy
    # NOTE(review): LambdaLR multiplies the optimizer's base lr (already args.lr)
    # by this factor, which itself contains args.lr, giving an lr^2-shaped
    # schedule — confirm this is the intended behavior.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri, sigmoid=False)
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
        domain_adv = DDP(domain_adv, delay_allreduce=True)
    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    # drop_last keeps per-domain batch sizes constant, which the concat+split
    # logic in train() relies on.
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)
    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only defined when local_rank == 0, so a
        # distributed 'test' run would raise NameError on non-zero ranks.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle the per-rank shards differently each epoch.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer,
              args)
        # evaluate on validation set
        # Note: prec1 from the LAST entry of args.test_list drives checkpoint selection.
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv,
          optimizer, epoch, writer, args):
    """Run one DANN training epoch.

    Each step combines a supervised classification loss on labeled source
    batches with a domain-adversarial transfer loss between source and
    target features, weighted by args.trade_off, under apex AMP scaling.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    # The unlabeled loader is cycled endlessly; the range() bound of
    # num_iterations caps the epoch length.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    # Labeled batches are (input, target, metadata); unlabeled batches carry
    # no targets, only (input, metadata).
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        # Forward source and target together in one batch, then split back.
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = domain_adv(feature_s, feature_t)
        loss = loss_s + loss_trans * args.trade_off
        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
                # domain_adv is DDP-wrapped in distributed mode; reach through .module.
                domain_acc = domain_adv.module.domain_discriminator_accuracy
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data
                domain_acc = domain_adv.domain_discriminator_accuracy
            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            domain_accs.update(to_python_float(domain_acc), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i
            torch.cuda.synchronize()
            # Average wall time per iteration over the last print_freq steps.
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans,
                          domain_acc=domain_accs, top1=top1))
if __name__ == '__main__':
    # Architectures offered for --arch: every lowercase public callable
    # exported by torchvision.models.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='DANN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    # LOCAL_RANK is injected by the distributed launcher (one process per GPU).
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    # apex AMP knobs, forwarded verbatim to amp.initialize.
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # NOTE(review): argparse `type=bool` turns ANY non-empty string into True
    # (e.g. "--channels-last false" is truthy) — consider action='store_true'.
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/dann.sh
================================================
# DANN on WILDS FMoW: DenseNet-121 backbone, AutoAugment "v0", apex AMP level O1,
# deterministic cudnn, extra vertical flips (aerial imagery has no canonical up).
CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dann/fmow/lr_0_1_aa_v0_densenet121
# DANN on WILDS iWildCam: 448x448 inputs with random resize-crop disabled
# (crop-pct/scale = 1.0), 24/24 source/target batches, F1-macro selection metric.
CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \
  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
  --log logs/dann/iwildcam/lr_1_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/erm.py
================================================
"""
Adapted from https://github.com/NVIDIA/apex/tree/master/examples
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.modules.classifier import Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Train/evaluate an ERM (source-only) classifier on a WILDS dataset.

    Builds the data pipelines and classifier, initializes apex AMP (and apex
    DistributedDataParallel when launched with WORLD_SIZE > 1), then either
    runs supervised training with per-epoch evaluation, or — in 'test'
    phase — only evaluates the best checkpoint.
    """
    writer = None
    # Checkpoints, tensorboard and verbose prints happen on rank 0 only.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    # Reproducibility mode: disable the cudnn autotuner and seed torch.
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # WORLD_SIZE is exported by the distributed launcher; > 1 means multi-process.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        # One process per GPU; local rank selects the device.
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # Use cosine annealing learning rate strategy
    # NOTE(review): LambdaLR multiplies the optimizer's base lr (already args.lr)
    # by this factor, which itself contains args.lr, giving an lr^2-shaped
    # schedule — confirm this is the intended behavior.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler)
    # The unlabeled loader is built for interface parity with the adaptation
    # scripts but is not consumed by ERM training below.
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler)
    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only defined when local_rank == 0, so a
        # distributed 'test' run would raise NameError on non-zero ranks.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle the per-rank shards differently each epoch.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)
        # evaluate on validation set
        # Note: prec1 from the LAST entry of args.test_list drives checkpoint selection.
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_loader, model, criterion, optimizer, epoch, writer, args):
    """Run one epoch of supervised (ERM) training with apex AMP loss scaling."""
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target, metadata) in enumerate(train_loader):
        # compute output
        # The model returns (logits, features); features are unused here.
        output, _ = model(input.cuda())
        loss = criterion(output, target.cuda())
        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output.data, target.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss = loss.data
            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            global_step = epoch * len(train_loader) + i
            torch.cuda.synchronize()
            # Average wall time per iteration over the last print_freq steps.
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss", to_python_float(reduced_loss), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input.cpu(), target, output.cpu(), args.class_names,
                                                           metadata, train_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses, top1=top1))
if __name__ == '__main__':
    # Architectures offered for --arch: every lowercase public callable
    # exported by torchvision.models.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='Src Only')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    # LOCAL_RANK is injected by the distributed launcher (one process per GPU).
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    # apex AMP knobs, forwarded verbatim to amp.initialize.
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # NOTE(review): argparse `type=bool` turns ANY non-empty string into True
    # (e.g. "--channels-last false" is truthy) — consider action='store_true'.
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/erm.sh
================================================
# ERM (source-only) baseline on WILDS FMoW: DenseNet-121, AutoAugment "v0",
# apex AMP level O1, deterministic cudnn, extra vertical flips.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121
# ERM baseline on WILDS iWildCam: 448x448 inputs with random resize-crop disabled
# (crop-pct/scale = 1.0), 24/24 batches, F1-macro selection metric.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \
  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 -p 500 --metric "F1-macro_all" \
  --log logs/erm/iwildcam/lr_1_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/fixmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.modules.classifier import Classifier
from tllib.vision.transforms import MultipleApply
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
class ImageClassifier(Classifier):
    """Classifier with a Linear -> BatchNorm1d -> ReLU bottleneck.

    Unlike the base class, ``forward`` returns only the class predictions
    (no feature tensor), which is what the FixMatch training loop consumes.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=512, **kwargs):
        # Project backbone features down to bottleneck_dim before the head.
        projection = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, projection, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """Return class prediction logits for a batch of images ``x``."""
        pooled = self.pool_layer(self.backbone(x))
        return self.head(self.bottleneck(pooled))
def main(args):
    """Entry point: set up (optionally distributed) mixed-precision training and
    run FixMatch on a WILDS dataset, evaluating on every split in
    ``args.test_list`` after each epoch and keeping the best checkpoint.
    """
    writer = None
    # Only rank 0 creates the logger / TensorBoard writer and prints the config.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade cudnn autotuning for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # Distributed setup: the launcher (torchrun / torch.distributed.launch) sets WORLD_SIZE.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code.
    # FixMatch uses two views of each unlabeled image: the weak view produces
    # pseudo labels, the strong view (color jitter + AutoAugment) is trained
    # against them via the consistency loss.
    weak_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=None,
        auto_augment=None,
        interpolation=args.interpolation,
    )
    strong_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    train_source_transform = strong_transform
    # Each unlabeled (target) sample yields a (weak, strong) image pair.
    train_target_transform = MultipleApply([weak_transform, strong_transform])
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_source_transform: ", train_source_transform)
        print('train_target_transform: ', train_target_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_source_transform, val_transform, verbose=args.local_rank == 0,
                          transform_train_target=train_target_transform)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ImageClassifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                            pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # Use cosine annealing learning rate strategy
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # define loss function (criterion) for the labeled (source) branch
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)
    if args.phase == 'test':
        # resume from the best checkpoint
        # NOTE(review): ``logger`` only exists on rank 0 — the test phase appears
        # to assume single-process execution; confirm before distributed testing.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    for epoch in range(args.epochs):
        if args.distributed:
            # reshuffle each rank's shard differently every epoch
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args)
        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args):
    """Train for one epoch with the FixMatch objective.

    Each step combines a supervised loss on a labeled source batch with a
    self-training loss: pseudo labels are predicted from the weakly-augmented
    target view (no gradient), kept only where the softmax confidence exceeds
    ``args.threshold``, and used as targets for the strongly-augmented view.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_self_training = AverageMeter('Loss (self training)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    # cycle() restarts the unlabeled loader when exhausted; the epoch length is
    # bounded by the shorter of the two loaders
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), ((input_t, input_t_strong), metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        # pseudo labels come from the weak view; no gradient through this pass
        with torch.no_grad():
            output_t = model(input_t.cuda())
        confidence, pseudo_labels = F.softmax(output_t, dim=1).max(dim=1)
        # only confident pseudo labels contribute to the self-training loss
        mask = (confidence > args.threshold).float()
        # single forward pass over source + strong target views
        input = torch.cat([input_s.cuda(), input_t_strong.cuda()], dim=0)
        output = model(input)
        output_s, output_t_strong = output.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_self_training = args.trade_off * \
            (F.cross_entropy(output_t_strong, pseudo_labels, reduction='none') * mask).mean()
        loss = loss_s + loss_self_training
        # compute gradient and do SGD step
        optimizer.zero_grad()
        # amp scales the loss to avoid fp16 gradient underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_self_training = utils.reduce_tensor(loss_self_training.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_self_training = loss_self_training.data
            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_self_training.update(to_python_float(reduced_loss_self_training), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (self training)", to_python_float(reduced_loss_self_training),
                                  global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (self training) {loss_self_training.val:.10f} ({loss_self_training.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s, loss_self_training=losses_self_training,
                          top1=top1))
def str2bool(v):
    """Parse a command-line boolean flag value.

    ``type=bool`` in argparse converts any non-empty string — including the
    string "False" — to True; this parser handles common spellings explicitly.
    Raises ``argparse.ArgumentTypeError`` on unrecognized input.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if v.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(v))


if __name__ == '__main__':
    # Every lowercase torchvision model constructor is a valid --arch choice.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='FixMatch')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: 224 224)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    parser.add_argument('--threshold', default=0.7, type=float,
                        help='confidence threshold')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # was ``type=bool``, which silently parsed "--channels-last False" as True
    parser.add_argument('--channels-last', type=str2bool, default=False)
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/fixmatch.sh
================================================
# FixMatch on WILDS fMoW: DenseNet-121 backbone, AutoAugment "v0" for the strong view,
# vertical flips enabled (aerial imagery), apex mixed precision (O1).
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/fixmatch/fmow/lr_0_1_aa_v0_densenet121
# FixMatch on WILDS iWildCam: 448x448 full-image crops, "extra_unlabeled" split,
# fewer epochs and smaller batches; model selection by macro-F1.
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" \
  --lr 0.3 --opt-level O1 --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 12 -b 24 24 -p 500 \
  --metric "F1-macro_all" --log logs/fixmatch/iwildcam/lr_0_3_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/jan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier
from tllib.modules.kernels import GaussianKernel
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Entry point: set up (optionally distributed) mixed-precision training and
    run JAN (Joint Adaptation Networks) on a WILDS dataset, evaluating on every
    split in ``args.test_list`` after each epoch and keeping the best checkpoint.
    """
    writer = None
    # Only rank 0 creates the logger / TensorBoard writer and prints the config.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade cudnn autotuning for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    # Distributed setup: the launcher (torchrun / torch.distributed.launch) sets WORLD_SIZE.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)
    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # Use cosine annealing learning rate strategy
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # define loss function (criterion) for the supervised (source) branch
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()
    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)
    if args.phase == 'test':
        # resume from the best checkpoint
        # NOTE(review): ``logger`` only exists on rank 0 — the test phase appears
        # to assume single-process execution; confirm before distributed testing.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    # define loss function
    # JMMD aligns the joint distribution of (features, predictions) across domains:
    # multiple Gaussian kernels on the features, one fixed-sigma kernel on the
    # softmax outputs.
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear
    )
    for epoch in range(args.epochs):
        if args.distributed:
            # reshuffle each rank's shard differently every epoch
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer, epoch, writer, args)
        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)
        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer, epoch, writer, args):
    """Train for one epoch with the JAN objective: supervised cross-entropy on
    labeled source batches plus a JMMD transfer loss that aligns the joint
    (feature, prediction) distributions of source and target domains, weighted
    by ``args.trade_off``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')
    # switch to train mode
    model.train()
    end = time.time()
    # cycle() restarts the unlabeled loader when exhausted; the epoch length is
    # bounded by the shorter of the two loaders
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        # single forward pass over the concatenated source + target batch
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = jmmd_loss(
            (feature_s, F.softmax(output_s, dim=1)),
            (feature_t, F.softmax(output_t, dim=1))
        )
        loss = loss_s + loss_trans * args.trade_off
        # compute gradient and do SGD step
        optimizer.zero_grad()
        # amp scales the loss to avoid fp16 gradient underflow
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data
            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time,
                          loss_s=losses_s, loss_trans=losses_trans, top1=top1))
def str2bool(v):
    """Parse a command-line boolean flag value.

    ``type=bool`` in argparse converts any non-empty string — including the
    string "False" — to True; this parser handles common spellings explicitly.
    Raises ``argparse.ArgumentTypeError`` on unrecognized input.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', 't', 'yes', 'y', '1'):
        return True
    if v.lower() in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(v))


if __name__ == '__main__':
    # Every lowercase torchvision model constructor is a valid --arch choice.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='JAN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: 224 224)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--linear', default=False, action='store_true',
                        help='whether use the linear version')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # was ``type=bool``, which silently parsed "--channels-last False" as True
    parser.add_argument('--channels-last', type=str2bool, default=False)
    parser.add_argument("--log", type=str, default='jan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/jan.sh
================================================
# JAN on WILDS fMoW: DenseNet-121 backbone, AutoAugment "v0" policy,
# vertical flips enabled (aerial imagery), apex mixed precision (O1).
CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/jan/fmow/lr_0_1_aa_v0_densenet121
# JAN on WILDS iWildCam: 448x448 full-image crops, "extra_unlabeled" split,
# reduced transfer-loss weight (0.3); model selection by macro-F1.
CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \
  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
  --log logs/jan/iwildcam/lr_0_3_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/mdd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp, optimizers
from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \
as MarginDisparityDiscrepancy, ImageClassifier as Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Entry point for MDD domain adaptation on a WILDS dataset.

    Sets up (optionally distributed, NCCL) training with apex mixed precision,
    builds the labeled/unlabeled loaders, then either tests a saved checkpoint
    (``args.phase == 'test'``) or runs the full training loop, checkpointing
    the best model by the metric returned from ``utils.validate``.
    """
    writer = None
    # Only rank 0 owns the logger / TensorBoard writer; other ranks keep writer=None
    # and never define `logger`.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Reproducibility: disable the cuDNN autotuner, pin algorithms, fix the seed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # Distributed mode is inferred from the launcher environment (torch.distributed.launch
    # / torchrun set WORLD_SIZE); each process drives one GPU selected by --local_rank.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       width=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch)
    mdd = MarginDisparityDiscrepancy(args.margin)

    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=True)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    # Use cosine annealing learning rate strategy
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by this factor, and the
    # factor itself already contains args.lr — the effective lr starts near args.lr**2.
    # Confirm this is the intended schedule.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)

    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only created when local_rank == 0, so the test
        # phase appears to assume a single (rank-0) process — confirm.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return

    for epoch in range(args.epochs):
        if args.distributed:
            # Re-shuffle each epoch so every rank sees a different partition.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        # NOTE(review): passing `epoch` to step() is deprecated in recent PyTorch.
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd, optimizer, epoch, writer, args)

        # evaluate on validation set
        # prec1 keeps the metric of the LAST dataset in test_list — that one drives
        # best-checkpoint selection below.
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
def train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd,
          optimizer, epoch, writer, args):
    """Run one epoch of MDD training.

    Each step forwards a concatenated (source + target) batch through the
    classifier, computes the supervised loss on the source half and the margin
    disparity discrepancy across both halves, and updates with apex amp loss
    scaling. Metrics are logged every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()

    end = time.time()
    # Cap the epoch at the shorter loader; cycle() lets the unlabeled loader
    # restart so zip never exhausts it first.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        # Concatenate source and target so BatchNorm sees the mixed batch and
        # both domains share one forward pass.
        n_s, n_t = len(input_s), len(input_t)
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, output_adv = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        output_adv_s, output_adv_t = output_adv.split([n_s, n_t], dim=0)

        loss_s = criterion(output_s, target_s.cuda())
        # Negated: the adversarial head maximizes the discrepancy (via the
        # gradient reversal inside the model) while the features minimize it.
        loss_trans = -mdd(output_s, output_adv_s, output_t, output_adv_t)
        loss = loss_s + loss_trans * args.trade_off

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        # step() advances the model's internal schedule (e.g. the GRL coefficient);
        # under DDP the wrapped module holds it.
        if args.distributed:
            model.module.step()
        else:
            model.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data

            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names,
                                                           metadata_s, train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                    epoch, i, len(train_labeled_loader),
                    args.world_size * args.batch_size[0] / batch_time.val,
                    args.world_size * args.batch_size[0] / batch_time.avg,
                    batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans, top1=top1))
if __name__ == '__main__':
    # NOTE(review): model_names is taken from torchvision.models, but the backbone is
    # created via utils.get_model (timm) — confirm the choices list matches what
    # utils.get_model actually accepts.
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='MDD')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float,
                        metavar='N', help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1,
                        help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=2048, type=int)
    parser.add_argument('--margin', type=float, default=4., help="margin gamma")
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N',
                        help='mini-batch size per process for source and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    # LOCAL_RANK is set by the distributed launcher; defaults to single-process rank 0.
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # NOTE(review): argparse type=bool treats ANY non-empty string as True
    # ("--channels-last False" -> True); action='store_true' would be safer — confirm usage.
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument("--log", type=str, default='mdd',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/mdd.sh
================================================
# MDD on WILDS-FMoW: DenseNet-121 backbone, AutoAugment "v0" policy, apex AMP O1,
# deterministic cuDNN, extra vertical-flip augmentation.
CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
--lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/mdd/fmow/lr_0_1_aa_v0_densenet121
# MDD on WILDS-iWildCam: 448x448 inputs (no random scaling), "extra_unlabeled" split as the
# target domain, smaller batches (24/24), F1-macro as the model-selection metric.
CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \
--deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
--log logs/mdd/iwildcam/lr_0_3_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/requirements.txt
================================================
wilds
timm
tensorflow
tensorboard
================================================
FILE: examples/domain_adaptation/wilds_image_classification/utils.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import time
import math
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
import wilds
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
import timm
sys.path.append('../../..')
from tllib.vision.transforms import Denormalize
from tllib.utils.meter import AverageMeter, ProgressMeter
def get_model_names():
    """Return the names of all architectures available in pytorch-image-models."""
    available = timm.list_models()
    return available
def get_model(model_name, pretrain=True):
    """Create a backbone from pytorch-image-models (timm) and strip its classifier.

    The returned module gains an ``out_features`` attribute (the input width of
    the removed classification head) so downstream heads can be sized from it.

    Args:
        model_name (str): a timm architecture name (see :func:`get_model_names`).
        pretrain (bool): whether to load pretrained weights. Default: True.
    """
    # load models from pytorch-image-models
    backbone = timm.create_model(model_name, pretrained=pretrain)
    try:
        backbone.out_features = backbone.get_classifier().in_features
        backbone.reset_classifier(0, '')
    except AttributeError:
        # Fallback for architectures that expose the classifier as ``head``
        # instead of get_classifier()/reset_classifier(). A bare ``except:``
        # here would also swallow KeyboardInterrupt/SystemExit, so we only
        # catch the missing-attribute case.
        backbone.out_features = backbone.head.in_features
        backbone.head = nn.Identity()
    return backbone
def get_dataset(dataset_name, root, unlabeled_list=("test_unlabeled",), test_list=("test",),
                transform_train=None, transform_test=None, verbose=True, transform_train_target=None):
    """Load a WILDS dataset and split it into labeled train / unlabeled train / test subsets.

    Returns a 5-tuple: (labeled train subset, concatenation of the requested
    unlabeled subsets, list of test subsets, number of classes, class names).
    """
    if transform_train_target is None:
        transform_train_target = transform_train

    # The labeled and unlabeled portions are served by two wilds dataset objects.
    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)
    unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)
    num_classes = labeled_dataset.n_classes

    train_labeled_dataset = labeled_dataset.get_subset("train", transform=transform_train)
    train_unlabeled_datasets = []
    for subset_name in unlabeled_list:
        train_unlabeled_datasets.append(
            unlabeled_dataset.get_subset(subset_name, transform=transform_train_target))
    train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)
    test_datasets = []
    for subset_name in test_list:
        test_datasets.append(labeled_dataset.get_subset(subset_name, transform=transform_test))

    # fmow ships human-readable category names; other datasets fall back to indices.
    if dataset_name == "fmow":
        from wilds.datasets.fmow_dataset import categories
        class_names = categories
    else:
        class_names = list(range(num_classes))

    if verbose:
        print("Datasets")
        subset_names = ["train"] + unlabeled_list + test_list
        subsets = [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets
        for n, d in zip(subset_names, subsets):
            print("\t{}:{}".format(n, len(d)))
        print("\t#classes:", num_classes)
    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names
def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds
    If vec is a list of Tensors, it concatenates them all along the first dimension.
    If vec is a list of lists, it joins these lists together, but does not attempt to
    recursively collate. This allows each element of the list to be, e.g., its own dict.
    If vec is a list of dicts (with the same keys in each dict), it returns a single dict
    with the same keys. For each key, it recursively collates all entries in the list.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    first = vec[0]
    # Dispatch on the type of the first element; all elements are assumed uniform.
    if torch.is_tensor(first):
        return torch.cat(vec)
    if isinstance(first, list):
        flattened = []
        for sublist in vec:
            flattened.extend(sublist)
        return flattened
    if isinstance(first, dict):
        return {key: collate_list([entry[key] for entry in vec]) for key in first}
    raise TypeError("Elements of the list to collate must be tensors or dicts.")
def get_train_transform(img_size, scale=None, ratio=None, hflip=0.5, vflip=0.,
                        color_jitter=0.4, auto_augment=None, interpolation='bilinear'):
    """Build the training augmentation pipeline (adapted from timm's transform factory).

    Order: RandomResizedCrop -> optional horizontal/vertical flips -> either an
    auto-augment policy (when ``auto_augment`` is set) or ColorJitter -> ToTensor
    -> ImageNet normalization. Auto-augment and color jitter are mutually
    exclusive: setting ``auto_augment`` disables color jitter.
    """
    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
    transforms_list = [
        transforms.RandomResizedCrop(img_size, scale=scale, ratio=ratio, interpolation=_pil_interp(interpolation))]
    if hflip > 0.:
        transforms_list += [transforms.RandomHorizontalFlip(p=hflip)]
    if vflip > 0.:
        transforms_list += [transforms.RandomVerticalFlip(p=vflip)]
    if auto_augment:
        assert isinstance(auto_augment, str)
        if isinstance(img_size, (tuple, list)):
            img_size_min = min(img_size)
        else:
            img_size_min = img_size
        # Parameters forwarded to timm's AA builders; fill color is the ImageNet
        # mean scaled to 0-255.
        aa_params = dict(
            translate_const=int(img_size_min * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in IMAGENET_DEFAULT_MEAN]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        # The policy string prefix selects the AA flavour: rand-* -> RandAugment,
        # augmix-* -> AugMix, anything else -> original AutoAugment policies.
        if auto_augment.startswith('rand'):
            transforms_list += [rand_augment_transform(auto_augment, aa_params)]
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            transforms_list += [augment_and_mix_transform(auto_augment, aa_params)]
        else:
            transforms_list += [auto_augment_transform(auto_augment, aa_params)]
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
            # or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
            color_jitter = (float(color_jitter),) * 3
        transforms_list += [transforms.ColorJitter(*color_jitter)]
    transforms_list += [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ]
    return transforms.Compose(transforms_list)
def get_val_transform(img_size=224, crop_pct=None, interpolation='bilinear'):
    """Build the evaluation pipeline: resize to img_size/crop_pct, center-crop
    to img_size, then ToTensor + ImageNet normalization."""
    crop_pct = crop_pct or DEFAULT_CROP_PCT
    if not isinstance(img_size, (tuple, list)):
        scale_size = int(math.floor(img_size / crop_pct))
    else:
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple(int(x / crop_pct) for x in img_size)
    pipeline = [
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
    return transforms.Compose(pipeline)
def _pil_interp(method):
    """Map an interpolation name to the corresponding PIL resampling constant."""
    named = {
        'bicubic': Image.BICUBIC,
        'lanczos': Image.LANCZOS,
        'hamming': Image.HAMMING,
    }
    # default bilinear, do we want to allow nearest?
    return named.get(method, Image.BILINEAR)
def validate(val_dataset, model, epoch, writer, args):
    """Evaluate ``model`` on ``val_dataset`` and return the metric named by ``args.metric``.

    Collects predictions over the whole split, logs a sample-predictions figure
    and per-metric scalars to TensorBoard, and uses the wilds dataset's own
    ``eval`` to compute the metric dict.
    """
    val_sampler = None
    if args.distributed:
        val_sampler = DistributedSampler(val_dataset)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    # Full-split accumulators (for wilds eval) and one-sample-per-batch
    # accumulators (for the debug figure).
    all_y_true = []
    all_y_pred = []
    all_metadata = []
    sampled_inputs = []
    sampled_outputs = []
    sampled_targets = []
    sampled_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        all_y_true.append(target)
        all_y_pred.append(output.argmax(1))
        all_metadata.append(metadata)
        # keep the first sample of each batch for visualization
        sampled_inputs.append(input[0:1])
        sampled_targets.append(target[0:1])
        sampled_outputs.append(output[0:1])
        sampled_metadata.append(metadata[0:1])

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and i % args.print_freq == 0:
            progress.display(i)

    if args.local_rank == 0:
        writer.add_figure(
            'test/predictions vs. actuals',
            plot_classes_preds(
                collate_list(sampled_inputs),
                collate_list(sampled_targets),
                collate_list(sampled_outputs),
                args.class_names,
                collate_list(sampled_metadata),
                val_dataset.metadata_map,
                nrows=min(int(len(val_loader) / 4), 50)
            ),
            global_step=epoch
        )

    # evaluate
    # NOTE(review): with DistributedSampler each rank only sees a shard, yet eval
    # runs per-rank on its local predictions; also the add_scalar loop below is not
    # rank-guarded while callers pass writer=None on non-zero ranks — this path
    # looks single-process only; confirm.
    results = val_dataset.eval(
        collate_list(all_y_pred),
        collate_list(all_y_true),
        collate_list(all_metadata)
    )
    print(results[1])
    for k, v in results[0].items():
        # skip empty groups and the catch-all "Other" entries
        if v == 0 or "Other" in k:
            continue
        writer.add_scalar("test/{}".format(k), v, global_step=epoch)
    return results[0][args.metric]
def reduce_tensor(tensor, world_size):
    """All-reduce ``tensor`` across processes and return its average.

    Args:
        tensor (torch.Tensor): this process's local value. It is cloned first,
            so the caller's tensor is left untouched.
        world_size (int): number of participating processes.

    Returns:
        torch.Tensor: element-wise mean over all processes.
    """
    rt = tensor.clone()
    # dist.reduce_op is deprecated (and removed in recent PyTorch releases);
    # dist.ReduceOp is the supported spelling.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
def matplotlib_imshow(img):
    """helper function to show an image"""
    # undo ImageNet normalization, then CHW tensor -> HWC array for matplotlib
    restored = Denormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)
    plt.imshow(np.transpose(restored.numpy(), (1, 2, 0)))
def plot_classes_preds(images, labels, outputs, class_names, metadata, metadata_map, nrows=4):
    '''
    Build a matplotlib Figure showing up to ``nrows * 4`` images from a batch,
    each titled with the model's top prediction and its softmax probability,
    the ground-truth label, and the sample's domain. Titles are green when the
    prediction is correct and red otherwise.
    '''
    # convert output probabilities to predicted class
    preds = np.squeeze(torch.max(outputs, 1)[1].numpy())
    probs = [F.softmax(row, dim=0)[p].item() for p, row in zip(preds, outputs)]

    # plot the images in the batch, along with predicted and true labels
    fig = plt.figure(figsize=(12, nrows * 4))
    domains = get_domain_names(metadata, metadata_map)
    n_cells = min(nrows * 4, len(images))
    for cell in np.arange(n_cells):
        ax = fig.add_subplot(nrows, 4, cell + 1, xticks=[], yticks=[])
        matplotlib_imshow(images[cell])
        ax.set_title("{0}, {1:.1f}%\n(label: {2}\ndomain: {3})".format(
            class_names[preds[cell]],
            probs[cell] * 100.0,
            class_names[labels[cell]],
            domains[cell],
        ), color=("green" if preds[cell] == labels[cell].item() else "red"))
    return fig
def get_domain_names(metadata, metadata_map):
    # NOTE(review): metadata_map is currently unused — raw integer domain ids are
    # returned instead of human-readable names. Confirm whether mapping through
    # metadata_map was intended here.
    return get_domain_ids(metadata)
def get_domain_ids(metadata):
    """Extract the integer domain id (first metadata column) of every sample."""
    return [int(row[0]) for row in metadata]
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/README.md
================================================
# Unsupervised Domain Adaptation for WILDS (Molecule classification)
## Installation
It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.
Then, you need to run
```
pip install -r requirements.txt
```
Finally, install torch_sparse by following the instructions at `https://github.com/rusty1s/pytorch_sparse`.
## Dataset
The following datasets can be downloaded automatically:
- [OGB-MolPCBA (WILDS)](https://wilds.stanford.edu/datasets/)
## Supported Methods
TODO
## Usage
The shell files give all the training scripts we use, e.g.
```
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \
--seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic
```
## Results
### Performance on WILDS-OGB-MolPCBA (GIN-virtual)
| Methods | Val Avg Precision | Test Avg Precision | GPU Memory Usage (GB) |
| --- | --- | --- | --- |
| ERM | 29.0 | 28.0 | 17.8 |
### Visualization
We use tensorboard to record the training process and visualize the outputs of the models.
```
tensorboard --logdir=logs
```
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/erm.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import argparse
import shutil
import time
import pprint
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import wilds
import utils
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
def main(args):
    """Entry point for ERM training on WILDS-OGB-MolPCBA.

    Builds the dataset and model, trains for ``args.epochs`` epochs (or only
    evaluates when ``args.phase == 'test'``), selects the best checkpoint by the
    validation metric, and reports the corresponding test metric.
    """
    logger = CompleteLogger(args.log, args.phase)
    writer = SummaryWriter(args.log)
    pprint.pprint(args)

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    if args.deterministic:
        # Reproducibility: disable the cuDNN autotuner, pin algorithms, fix the seed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # Data loading code
    # There are no well-developed data augmentation techniques for molecular graphs.
    train_transform = None
    val_transform = None
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset('ogb-molpcba', args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = utils.get_model(args.arch, args.num_classes)
    model = model.cuda().to()
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay
    )

    # Data loading code
    train_labeled_sampler = None
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler,
        collate_fn=train_labeled_dataset.collate)

    # define loss function (criterion)
    criterion = utils.reduced_bce_logit_loss

    if args.phase == 'test':
        # resume from the latest checkpoint
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            utils.validate(d, model, -1, writer, args)
        return

    # start training
    best_val_metric = 0
    test_metric = 0
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)

        # evaluate on validation set
        # NOTE(review): tmp_val_metric / tmp_test_metric are only assigned when
        # 'val' / 'test' appear in args.test_list (and 'val' must come first for
        # the is_best bookkeeping below) — NameError otherwise; confirm the
        # default ['val', 'test'] is always used.
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            if n == 'val':
                tmp_val_metric = utils.validate(d, model, epoch, writer, args)
            elif n == 'test':
                tmp_test_metric = utils.validate(d, model, epoch, writer, args)

        # remember best mse and save checkpoint
        is_best = tmp_val_metric > best_val_metric
        best_val_metric = max(tmp_val_metric, best_val_metric)
        torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
        if is_best:
            # test_metric tracks the test score at the best-validation epoch
            test_metric = tmp_test_metric
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        print("best val performance: {:.3f}".format(best_val_metric))
        print("test performance: {:.3f}".format(test_metric))

    logger.close()
    writer.close()
def train(train_loader, model, criterion, optimizer, epoch, writer, args):
    """Train the molecule classifier for one epoch.

    Args:
        train_loader: loader yielding ``(input, target, metadata)`` batches.
        model: classification model already on the GPU.
        criterion: loss function taking ``(output, target)``.
        optimizer: optimizer over the model's parameters.
        epoch (int): current epoch index (used for logging only).
        writer: TensorBoard ``SummaryWriter``.
        args: namespace providing ``print_freq`` and ``batch_size``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input, target, metadata) in enumerate(train_loader):
        # compute output
        output = model(input.cuda())
        loss = criterion(output, target.cuda())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # .item() converts the loss to a Python float so the meter and
            # TensorBoard do not hold live CUDA tensors across iterations
            # (storing `loss` itself keeps device memory pinned until logging).
            loss_value = loss.item()
            losses.update(loss_value, input.size(0))
            global_step = epoch * len(train_loader) + i
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            writer.add_scalar('train/loss', loss_value, global_step)
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'.format(
                epoch, i, len(train_loader),
                args.batch_size[0] / batch_time.val,
                args.batch_size[0] / batch_time.avg,
                batch_time=batch_time, loss=losses))
if __name__ == '__main__':
    model_names = utils.get_model_names()
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='ogb-molpcba', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: ogb-molpcba)')
    parser.add_argument('--unlabeled-list', nargs='+', default=[])
    parser.add_argument('--test-list', nargs='+', default=['val', 'test'])
    parser.add_argument('--metric', default='ap',
                        help='metric used to evaluate model performance. (default: average precision)')
    parser.add_argument('--use-unlabeled', action='store_true',
                        help='Whether use unlabeled data for training or not.')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='gin_virtual',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: gin_virtual)')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Learning rate')
    parser.add_argument('--weight-decay', '--wd', default=0.0, type=float,
                        metavar='W', help='weight decay (default: 0.0)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=200, type=int, metavar='N',
                        help='number of total epochs to run')
    # Two sizes are accepted (source / target), though only the first is used
    # by the labeled loader in this script.
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=50, type=int,
                        metavar='N', help='print frequency (default: 50)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--log', type=str, default='src_only',
                        help='Where to save logs, checkpoints and debugging images.')
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/erm.sh
================================================
# ogb-molpcba
# ERM baseline on WILDS-OGB-MolPCBA: 4096-graph batches, 200 epochs, fixed seed,
# deterministic cuDNN.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \
--seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/gin.py
================================================
"""
Adapted from "https://github.com/p-lambda/wilds"
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_mean_pool, global_add_pool
import torch.nn.functional as F
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
__all__ = ['gin_virtual']
class GINVirtual(torch.nn.Module):
    """Graph Isomorphism Network with a virtual node, for multi-task binary
    graph classification.

    Args:
        num_tasks (int): number of binary label tasks; 128 matches ogbg-molpcba.
            If ``None``, the model acts as a feature extractor and returns
            graph embeddings instead of task logits.
        num_layers (int): number of message passing layers (must be >= 2)
        emb_dim (int): dimensionality of hidden node/graph embeddings
        dropout (float): dropout ratio applied to hidden channels

    Inputs:
        - batched Pytorch Geometric graph object

    Outputs:
        - prediction (tensor): float torch tensor of shape (num_graphs, num_tasks),
          or (num_graphs, emb_dim) when ``num_tasks`` is None
    """

    def __init__(self, num_tasks=128, num_layers=5, emb_dim=300, dropout=0.5):
        super(GINVirtual, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        # d_out records the width of whatever forward() produces.
        self.d_out = self.emb_dim if num_tasks is None else self.num_tasks
        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        # GNN producing per-node embeddings, then mean pooling to graph level.
        self.gnn_node = GINVirtualNode(num_layers, emb_dim, dropout=dropout)
        self.pool = global_mean_pool
        # Optional task head; omitted when used purely as a feature extractor.
        if num_tasks is None:
            self.graph_pred_linear = None
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, batched_data):
        node_emb = self.gnn_node(batched_data)
        graph_emb = self.pool(node_emb, batched_data.batch)
        if self.graph_pred_linear is None:
            return graph_emb
        return self.graph_pred_linear(graph_emb)
class GINVirtualNode(torch.nn.Module):
    """
    Helper module of the virtual-node-augmented Graph Isomorphism Network.
    Produces the per-node embeddings consumed by :class:`GINVirtual`.

    Args:
        num_layers (int): number of message passing layers of GNN (must be >= 2)
        emb_dim (int): dimensionality of hidden channels
        dropout (float, optional): dropout ratio applied to hidden channels. Default: 0.5

    Inputs:
        - batched Pytorch Geometric graph object

    Outputs:
        - node_embedding (tensor): float torch tensor of shape (num_nodes, emb_dim)
    """

    def __init__(self, num_layers, emb_dim, dropout=0.5):
        super(GINVirtualNode, self).__init__()
        self.num_layers = num_layers
        self.dropout = dropout
        if self.num_layers < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        self.atom_encoder = AtomEncoder(emb_dim)
        # One shared virtual-node embedding row, initialized to zero so the
        # first virtual-node -> node message is a no-op.
        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)
        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)
        # Per-layer GIN convolutions
        self.convs = torch.nn.ModuleList()
        # batch norms applied to node embeddings after each conv
        self.batch_norms = torch.nn.ModuleList()
        # MLPs that update the virtual node between layers; there is one fewer
        # than the number of conv layers because no update follows the last one.
        self.mlp_virtualnode_list = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(GINConv(emb_dim))
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))
        for layer in range(num_layers - 1):
            self.mlp_virtualnode_list.append(
                torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                    torch.nn.ReLU(),
                                    torch.nn.Linear(2 * emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim),
                                    torch.nn.ReLU()))

    def forward(self, batched_data):
        x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch
        # One virtual node per graph: batch[-1] + 1 gives the graph count
        # (assumes `batch` is sorted/contiguous as produced by PyG batching
        # -- TODO confirm). The zeros vector indexes the single shared
        # embedding row; it is cast to edge_index's integer dtype so it is a
        # valid embedding index.
        virtualnode_embedding = self.virtualnode_embedding(
            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device))
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layers):
            # add message from virtual nodes to graph nodes
            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]
            # Message passing among graph nodes
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layers - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.dropout, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.dropout, training=self.training)
            h_list.append(h)
            # update the virtual nodes (skipped after the final layer)
            if layer < self.num_layers - 1:
                # add message from graph nodes to virtual nodes; note the
                # pooled input is the pre-conv features h_list[layer]
                virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding
                # transform virtual nodes using MLP
                virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp),
                                                  self.dropout, training=self.training)
        node_embedding = h_list[-1]
        return node_embedding
class GINConv(MessagePassing):
    """GIN message-passing layer with bond-feature edge encoding.

    Args:
        emb_dim (int): node embedding dimensionality

    Inputs:
        - x (tensor): node embedding
        - edge_index (tensor): edge connectivity information
        - edge_attr (tensor): edge feature

    Outputs:
        - prediction (tensor): output node embedding
    """

    def __init__(self, emb_dim):
        super(GINConv, self).__init__(aggr="add")
        # MLP applied after aggregation: emb_dim -> 2*emb_dim -> emb_dim.
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(emb_dim, 2 * emb_dim),
            torch.nn.BatchNorm1d(2 * emb_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2 * emb_dim, emb_dim),
        )
        # Learnable epsilon weighting of the central node (GIN formulation).
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x, edge_index, edge_attr):
        edge_emb = self.bond_encoder(edge_attr)
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_emb)
        return self.mlp((1 + self.eps) * x + aggregated)

    def message(self, x_j, edge_attr):
        # Combine neighbor embedding with edge embedding before aggregation.
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
def gin_virtual(num_tasks, dropout=0.5):
    """Construct a :class:`GINVirtual` model with the default depth and width.

    Args:
        num_tasks (int): number of binary prediction tasks (``None`` yields a
            feature extractor without the task head)
        dropout (float, optional): dropout ratio. Default: 0.5
    """
    return GINVirtual(num_tasks=num_tasks, dropout=dropout)
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/requirements.txt
================================================
torch_geometric
wilds
tensorflow
tensorboard
ogb
================================================
FILE: examples/domain_adaptation/wilds_ogb_molpcba/utils.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import time
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
import wilds
sys.path.append('../../..')
import gin as models
from tllib.utils.meter import AverageMeter, ProgressMeter
def reduced_bce_logit_loss(y_pred, y_target):
    """Binary cross-entropy with logits, averaged over labeled targets only.

    Each row of ``y_target`` may contain NaN entries for tasks without labels
    (as in ogbg-molpcba); those entries are masked out before the loss is
    computed so they do not poison the mean.

    Args:
        y_pred (tensor): raw logits, same shape as ``y_target``
        y_target (tensor): binary targets with NaN marking unlabeled entries

    Returns:
        tensor: scalar mean BCE loss over the labeled entries (NaN if none
        are labeled, matching the original ``mean()`` over an empty tensor).
    """
    is_labeled = ~torch.isnan(y_target)
    # Build the criterion with 'mean' reduction directly instead of reducing
    # by hand, and drop the needless .cuda() call: BCEWithLogitsLoss has no
    # parameters to move, and the old call tied this function to GPU hosts.
    criterion = nn.BCEWithLogitsLoss(reduction='mean')
    return criterion(y_pred[is_labeled].float(), y_target[is_labeled].float())
def get_dataset(dataset_name, root, unlabeled_list=('test_unlabeled',), test_list=('test',),
                transform_train=None, transform_test=None, use_unlabeled=True, verbose=True):
    """Load a WILDS dataset and split it into labeled/unlabeled/test subsets.

    Args:
        dataset_name (str): name understood by ``wilds.get_dataset`` (e.g. 'ogb-molpcba')
        root (str): root directory for downloading/locating the data
        unlabeled_list (sequence of str): unlabeled split names to concatenate
        test_list (sequence of str): evaluation split names
        transform_train: transform for the train/unlabeled subsets
        transform_test: transform for the evaluation subsets
        use_unlabeled (bool): whether to also load the unlabeled splits
        verbose (bool): print split sizes and the class count

    Returns:
        tuple: (train_labeled_dataset, train_unlabeled_dataset or None,
        test_datasets, num_classes, class_names)
    """
    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)
    train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train)
    if use_unlabeled:
        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)
        train_unlabeled_datasets = [
            unlabeled_dataset.get_subset(u, transform=transform_train)
            for u in unlabeled_list
        ]
        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)
    else:
        unlabeled_list = []
        train_unlabeled_datasets = []
        train_unlabeled_dataset = None
    test_datasets = [
        labeled_dataset.get_subset(t, transform=transform_test)
        for t in test_list
    ]
    # ogb-molpcba is multi-task binary classification: the label width
    # (y_size), not a class count, is what the model's head needs.
    if dataset_name == 'ogb-molpcba':
        num_classes = labeled_dataset.y_size
    else:
        num_classes = labeled_dataset.n_classes
    class_names = list(range(num_classes))
    if verbose:
        print('Datasets')
        # BUGFIX: unlabeled_list/test_list default to tuples; concatenating a
        # list with a tuple raises TypeError, so normalize both to lists here.
        for n, d in zip(['train'] + list(unlabeled_list) + list(test_list),
                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):
            print('\t{}:{}'.format(n, len(d)))
        print('\t#classes:', num_classes)
    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names
def get_model_names():
    """Sorted names of the callable, lowercase model factories in the gin module."""
    names = []
    for name in models.__dict__:
        if name.islower() and not name.startswith('__') and callable(models.__dict__[name]):
            names.append(name)
    return sorted(names)
def get_model(arch, num_classes):
    """Instantiate architecture ``arch`` with ``num_classes`` output tasks.

    Raises:
        ValueError: if ``arch`` is not a factory defined in the gin module.
    """
    if arch not in models.__dict__:
        raise ValueError('{} is not supported'.format(arch))
    return models.__dict__[arch](num_tasks=num_classes)
def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds
    Collate a list of tensors (concatenated along the first dimension), a list
    of lists (flattened one level, without recursing), or a list of dicts with
    identical keys (recursively collated per key into a single dict).
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    first = vec[0]
    if torch.is_tensor(first):
        return torch.cat(vec)
    if isinstance(first, list):
        flattened = []
        for sub in vec:
            flattened.extend(sub)
        return flattened
    if isinstance(first, dict):
        return {key: collate_list([entry[key] for entry in vec]) for key in first}
    raise TypeError("Elements of the list to collate must be tensors or dicts.")
def validate(val_dataset, model, epoch, writer, args):
    """Evaluate ``model`` on ``val_dataset`` and return ``results[0][args.metric]``.

    Predictions, targets and metadata are accumulated on CPU, then scored with
    the WILDS subset's own ``eval`` method. Non-zero metrics (excluding
    per-"Other"-group entries) are written to tensorboard under
    ``test/<name>`` at ``global_step=epoch``.
    """
    val_sampler = None
    # Graph batches require the dataset-provided collate function.
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=val_dataset.collate)
    all_y_true = []
    all_y_pred = []
    all_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')
    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        all_y_true.append(target)
        all_y_pred.append(output)
        all_metadata.append(metadata)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        # NOTE(review): assumes args.local_rank exists on the namespace even in
        # single-process runs -- confirm against the caller's argparse setup.
        if args.local_rank == 0 and i % args.print_freq == 0:
            progress.display(i)
    # evaluate: results[0] is the metric dict, results[1] a printable summary
    results = val_dataset.eval(
        collate_list(all_y_pred),
        collate_list(all_y_true),
        collate_list(all_metadata)
    )
    print(results[1])
    for k, v in results[0].items():
        # skip empty metrics and per-"Other"-group entries
        if v == 0 or "Other" in k:
            continue
        writer.add_scalar("test/{}".format(k), v, global_step=epoch)
    return results[0][args.metric]
================================================
FILE: examples/domain_adaptation/wilds_poverty/README.md
================================================
# Unsupervised Domain Adaptation for WILDS (Image Regression)
## Installation
It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.
You need to install apex following `https://github.com/NVIDIA/apex`. Then run
```
pip install -r requirements.txt
```
## Dataset
Following datasets can be downloaded automatically:
- [PovertyMap (WILDS)](https://wilds.stanford.edu/datasets/)
## Supported Methods
TODO
## Usage
Our code is based
on [https://github.com/NVIDIA/apex/tree/master/examples/imagenet](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
. It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and
VGG, on the WILDS dataset.
Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed
precision "optimization levels" or `opt_level`s.
For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).
The shell files give all the training scripts we use, e.g.
```
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A
```
## Results
### Performance on WILDS-PovertyMap (ResNet18-MultiSpectral)
| Method | Val Pearson r | Test Pearson r | Val Worst-U/R Pearson r | Test Worst-U/R Pearson r | GPU Memory Usage(GB) |
| --- | --- | --- | --- | --- | --- |
| ERM | 0.80 | 0.80 | 0.54 | 0.50 | 3.5 |
### Distributed training
We use `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.
```
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 erm.py /data/wilds --arch 'resnet18_ms' \
--opt-level O1 --deterministic --log logs/erm/poverty --lr 1e-3 --wd 0.0 --epochs 200 --metric r_wg --split_scheme official -b 64 64 --fold A
```
### Visualization
We use tensorboard to record the training process and visualize the outputs of the models.
```
tensorboard --logdir=logs
```
================================================
FILE: examples/domain_adaptation/wilds_poverty/erm.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import argparse
import os
import shutil
import time
import pprint
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.tensorboard import SummaryWriter
try:
from apex.parallel import DistributedDataParallel as DDP
from apex.fp16_utils import *
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
import utils
from utils import Regressor
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
def main(args):
    """Entry point: set up (optionally distributed) apex-AMP training of an
    ERM regressor on WILDS PovertyMap, then run the train/test phase selected
    by ``args.phase``.
    """
    writer = None
    # Only rank 0 creates the logger / tensorboard writer and prints config.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    if args.deterministic:
        # Reproducibility: disable autotuning, force deterministic kernels,
        # and fix the RNG seed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        # torch.distributed.launch exports WORLD_SIZE; >1 means multi-process.
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format
    # Data loading code
    # Images in povertyMap dataset have 8 channels and traditional data augmentation
    # methods have no effect on performance.
    train_transform = None
    val_transform = None
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_channels = \
        utils.get_dataset('poverty', args.data_dir, args.unlabeled_list, args.test_list, args.split_scheme,
                          train_transform, val_transform, use_unlabeled=args.use_unlabeled,
                          verbose=args.local_rank == 0, fold=args.fold)
    # create model
    if args.local_rank == 0:
        print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.num_channels)
    pool_layer = nn.Identity() if args.no_pool else None
    # finetune=False: the multispectral backbone is trained from scratch.
    model = Regressor(backbone, pool_layer=pool_layer, finetune=False)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.lr, weight_decay=args.weight_decay)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=args.gamma, step_size=args.step_size)
    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
    # Data loading code
    train_labeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler)
    if args.phase == 'test':
        # resume from the best checkpoint
        # NOTE(review): `logger` is only defined on rank 0; a distributed test
        # phase would fail on other ranks -- confirm test is single-process.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return
    # start training
    best_val_metric = 0
    test_metric = 0
    for epoch in range(args.epochs):
        if args.distributed:
            train_labeled_sampler.set_epoch(epoch)
        # NOTE(review): stepping with an explicit epoch is deprecated, and the
        # step happens before train() -- verify the intended decay schedule.
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)
        # train for one epoch
        train(train_labeled_loader, model, optimizer, epoch, writer, args)
        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            if n == 'val':
                tmp_val_metric = utils.validate(d, model, epoch, writer, args)
            elif n == 'test':
                tmp_test_metric = utils.validate(d, model, epoch, writer, args)
        # remember best metric and save checkpoint
        # NOTE(review): tmp_val_metric / tmp_test_metric are unbound if 'val'
        # or 'test' is missing from --test-list -- confirm callers keep both.
        if args.local_rank == 0:
            is_best = tmp_val_metric > best_val_metric
            best_val_metric = max(tmp_val_metric, best_val_metric)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                test_metric = tmp_test_metric
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            print('best val performance: {:.3f}'.format(best_val_metric))
            print('test performance: {:.3f}'.format(test_metric))
def train(train_loader, model, optimizer, epoch, writer, args):
    """Train ``model`` for one epoch with MSE loss under apex AMP,
    logging loss and throughput every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target, metadata) in enumerate(train_loader):
        # compute output (Regressor returns (predictions, features) in train mode)
        output, _ = model(input.cuda())
        loss = F.mse_loss(output, target.cuda())
        # compute gradient and do optimizer step
        optimizer.zero_grad()
        # Scale the loss so fp16 gradients do not underflow (apex AMP).
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()
        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.
            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
            else:
                reduced_loss = loss.data
            # to_python_float incurs a host<->device sync
            # (imported via `from apex.fp16_utils import *` -- TODO confirm)
            losses.update(to_python_float(reduced_loss), input.size(0))
            global_step = epoch * len(train_loader) + i
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            if args.local_rank == 0:
                writer.add_scalar("train/loss", to_python_float(reduced_loss), global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time,
                          loss=losses))
if __name__ == '__main__':
    model_names = utils.get_model_names()
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('--unlabeled-list', nargs='+', default=[])
    parser.add_argument('--test-list', nargs='+', default=['val', 'test'])
    parser.add_argument('--metric', default='r_wg',
                        help='metric used to evaluate model performance.'
                             '(default: worst-U/R Pearson r)')
    parser.add_argument('--split-scheme', type=str,
                        help='Identifies how the train/val/test split is constructed.'
                             'Choices are dataset-specific.')
    parser.add_argument('--fold', type=str, default='A', choices=['A', 'B', 'C', 'D', 'E'],
                        help='Fold for poverty dataset. Poverty has 5 different cross validation folds,'
                             'each splitting the countries differently.')
    parser.add_argument('--use-unlabeled', action='store_true',
                        help='Whether use unlabeled data for training or not.')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18_ms',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: resnet18_ms)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Learning rate')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)')
    # BUGFIX: gamma is a multiplicative decay factor (default 0.96). With
    # type=int any fractional value on the command line crashed argparse;
    # type=float accepts the documented values and remains compatible with
    # integer inputs.
    parser.add_argument('--gamma', type=float, default=0.96, help='parameter for StepLR scheduler')
    parser.add_argument('--step-size', type=int, default=1, help='parameter for StepLR scheduler')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=50, type=int,
                        metavar='N', help='print frequency (default: 50)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true',
                        help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # NOTE(review): type=bool treats any non-empty string (even 'False') as
    # True; kept as-is for CLI compatibility -- omit the flag for the default.
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument('--log', type=str, default='src_only',
                        help='Where to save logs, checkpoints and debugging images.')
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_poverty/erm.sh
================================================
# ERM on WILDS PovertyMap with the official split scheme: train one
# ResNet18-MultiSpectral model per cross-validation fold (A-E), each with
# apex AMP opt-level O1 and deterministic cuDNN.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold B \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_B
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold C \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_C
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold D \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_D
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold E \
--arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_E
================================================
FILE: examples/domain_adaptation/wilds_poverty/requirements.txt
================================================
wilds
tensorflow
tensorboard
================================================
FILE: examples/domain_adaptation/wilds_poverty/resnet_ms.py
================================================
"""
Modified based on torchvision.models.resnet
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import torch.nn as nn
from torchvision import models
from torchvision.models.resnet import BasicBlock, Bottleneck
import copy
__all__ = ['resnet18_ms', 'resnet34_ms', 'resnet50_ms', 'resnet101_ms', 'resnet152_ms']
class ResNetMS(models.ResNet):
    """ResNet backbone accepting an arbitrary number of input channels.

    The stem convolution is rebuilt for ``in_channels`` inputs (PovertyMap
    images have 8 spectral bands) and the classification head is bypassed:
    :meth:`forward` returns the final convolutional feature map.
    """

    def __init__(self, in_channels, *args, **kwargs):
        super(ResNetMS, self).__init__(*args, **kwargs)
        # Replace the default 3-channel stem with one matching the input.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self._out_features = self.fc.in_features
        nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out',
                                nonlinearity='relu')

    def forward(self, x):
        # Stem: conv -> bn -> relu -> maxpool.
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages; avgpool/flatten/fc are intentionally skipped so
        # callers receive spatial features.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out

    @property
    def out_features(self) -> int:
        """The dimension of output features (width before the unused fc head)."""
        return self._out_features

    def copy_head(self) -> nn.Module:
        """Return a deep copy of the original fully connected layer."""
        return copy.deepcopy(self.fc)
def resnet18_ms(num_channels=3):
    """ResNet-18 multispectral backbone with ``num_channels`` input channels."""
    return ResNetMS(num_channels, BasicBlock, [2, 2, 2, 2])


def resnet34_ms(num_channels=3):
    """ResNet-34 multispectral backbone with ``num_channels`` input channels."""
    return ResNetMS(num_channels, BasicBlock, [3, 4, 6, 3])


def resnet50_ms(num_channels=3):
    """ResNet-50 multispectral backbone with ``num_channels`` input channels."""
    return ResNetMS(num_channels, Bottleneck, [3, 4, 6, 3])


def resnet101_ms(num_channels=3):
    """ResNet-101 multispectral backbone with ``num_channels`` input channels."""
    return ResNetMS(num_channels, Bottleneck, [3, 4, 23, 3])


def resnet152_ms(num_channels=3):
    """ResNet-152 multispectral backbone with ``num_channels`` input channels."""
    return ResNetMS(num_channels, Bottleneck, [3, 8, 36, 3])
================================================
FILE: examples/domain_adaptation/wilds_poverty/utils.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import time
import sys
from typing import Tuple, Optional, List, Dict
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.distributed import DistributedSampler
import wilds
import resnet_ms as models
sys.path.append('../../..')
from tllib.utils.meter import AverageMeter, ProgressMeter
class Regressor(nn.Module):
    """A generic Regressor class for domain adaptation.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1
        head (torch.nn.Module, optional): Any regressor head. Use :class:`torch.nn.Linear` by default
        finetune (bool): Whether finetune the regressor or train from scratch. Default: True
        pool_layer (torch.nn.Module, optional): pooling applied after the
            backbone; defaults to adaptive average pooling plus flatten

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        - predictions: regressor's predictions
        - features: features after `bottleneck` layer and before `head` layer
          (features are only returned in training mode)
    """

    def __init__(self, backbone: nn.Module, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None,
                 finetune=True, pool_layer=None):
        super(Regressor, self).__init__()
        self.backbone = backbone
        # Default pooling squeezes spatial dims and flattens to (N, C).
        if pool_layer is not None:
            self.pool_layer = pool_layer
        else:
            self.pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        if bottleneck is not None:
            assert bottleneck_dim > 0
            self.bottleneck = bottleneck
            self._features_dim = bottleneck_dim
        else:
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        # Single-output linear head unless the caller supplies one.
        self.head = head if head is not None else nn.Linear(self._features_dim, 1)
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """"""
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        predictions = self.head(features)
        # Features are exposed only during training (for adaptation losses).
        return (predictions, features) if self.training else predictions

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer: a finetuned
        backbone runs at one tenth of the base learning rate.
        """
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
def get_dataset(dataset_name, root, unlabeled_list=("test_unlabeled",), test_list=("test",),
                split_scheme='official', transform_train=None, transform_test=None, use_unlabeled=True,
                verbose=True, **kwargs):
    """Load a WILDS dataset and split it into labeled/unlabeled/test subsets.

    Args:
        dataset_name (str): name understood by ``wilds.get_dataset`` (e.g. 'poverty')
        root (str): root directory for downloading/locating the data
        unlabeled_list (sequence of str): unlabeled split names to concatenate
        test_list (sequence of str): evaluation split names
        split_scheme (str): how the train/val/test split is constructed
        transform_train: transform for the train/unlabeled subsets
        transform_test: transform for the evaluation subsets
        use_unlabeled (bool): whether to also load the unlabeled splits
        verbose (bool): print split sizes
        **kwargs: forwarded to ``wilds.get_dataset`` (e.g. ``fold`` for poverty)

    Returns:
        tuple: (train_labeled_dataset, train_unlabeled_dataset or None,
        test_datasets, num_channels)
    """
    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, split_scheme=split_scheme, **kwargs)
    train_labeled_dataset = labeled_dataset.get_subset("train", transform=transform_train)
    if use_unlabeled:
        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)
        train_unlabeled_datasets = [
            unlabeled_dataset.get_subset(u, transform=transform_train)
            for u in unlabeled_list
        ]
        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)
    else:
        unlabeled_list = []
        train_unlabeled_datasets = []
        train_unlabeled_dataset = None
    test_datasets = [
        labeled_dataset.get_subset(t, transform=transform_test)
        for t in test_list
    ]
    # Infer the channel count from the first sample (8 bands for poverty).
    num_channels = labeled_dataset.get_input(0).size()[0]
    if verbose:
        print("Datasets")
        # BUGFIX: unlabeled_list/test_list default to tuples; concatenating a
        # list with a tuple raises TypeError, so normalize both to lists here.
        for n, d in zip(["train"] + list(unlabeled_list) + list(test_list),
                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):
            print("\t{}:{}".format(n, len(d)))
    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_channels
def get_model_names():
    """Sorted names of the callable, lowercase model factories in resnet_ms."""
    candidates = (name for name in models.__dict__
                  if name.islower() and not name.startswith('__'))
    return sorted(name for name in candidates if callable(models.__dict__[name]))
def get_model(arch, num_channels):
    """Build backbone ``arch`` configured for ``num_channels`` input channels.

    Raises:
        ValueError: if ``arch`` is not a factory defined in resnet_ms.
    """
    if arch not in models.__dict__:
        raise ValueError('{} is not supported'.format(arch))
    factory = models.__dict__[arch]
    return factory(num_channels=num_channels)
def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds

    Collate a homogeneous list:
    - tensors are concatenated along the first dimension;
    - lists are flattened by one level (each element may itself be, e.g., a dict);
    - dicts (sharing the same keys) are merged into one dict whose values are
      collated recursively per key.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    first = vec[0]
    if torch.is_tensor(first):
        return torch.cat(vec)
    if isinstance(first, list):
        flattened = []
        for sublist in vec:
            flattened.extend(sublist)
        return flattened
    if isinstance(first, dict):
        return {key: collate_list([entry[key] for entry in vec]) for key in first}
    raise TypeError("Elements of the list to collate must be tensors or dicts.")
def reduce_tensor(tensor, world_size):
    """All-reduce ``tensor`` across ``world_size`` processes and return the mean.

    The tensor is cloned first so the caller's copy is left untouched.
    Requires ``torch.distributed`` to be initialized.
    """
    rt = tensor.clone()
    # dist.reduce_op is a long-deprecated alias (removed in recent PyTorch);
    # dist.ReduceOp is the supported spelling
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
def validate(val_dataset, model, epoch, writer, args):
    """Evaluate ``model`` on ``val_dataset`` using the WILDS evaluation suite.

    Predictions from all batches are gathered and scored by ``val_dataset.eval``;
    every non-zero metric whose name does not contain "Other" is logged to
    TensorBoard at ``global_step=epoch``, and the metric named by ``args.metric``
    is returned.

    NOTE(review): only rank 0 runs the evaluation and hits the final return;
    other ranks fall through and implicitly return None — confirm callers
    only use the return value on rank 0.
    """
    val_sampler = None
    if args.distributed:
        # shard the dataset so each process scores a disjoint slice
        val_sampler = DistributedSampler(val_dataset)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    all_y_true = []
    all_y_pred = []
    all_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output (raw model output is kept; no argmax is taken here)
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        all_y_true.append(target)
        all_y_pred.append(output)
        all_metadata.append(metadata)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and i % args.print_freq == 0:
            progress.display(i)

    if args.local_rank == 0:
        # evaluate with the dataset's official metric suite
        results = val_dataset.eval(
            collate_list(all_y_pred),
            collate_list(all_y_true),
            collate_list(all_metadata)
        )
        print(results[1])
        for k, v in results[0].items():
            # skip zero-valued metrics and per-group "Other" entries
            if v == 0 or "Other" in k:
                continue
            writer.add_scalar("test/{}".format(k), v, global_step=epoch)
        return results[0][args.metric]
================================================
FILE: examples/domain_adaptation/wilds_text/README.md
================================================
# Unsupervised Domain Adaptation for WILDS (Text Classification)
## Installation
It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results.
You need to run
```
pip install -r requirements.txt
```
## Dataset
Following datasets can be downloaded automatically:
- [CivilComments (WILDS)](https://wilds.stanford.edu/datasets/)
- [Amazon (WILDS)](https://wilds.stanford.edu/datasets/)
## Supported Methods
TODO
## Usage
The shell files give all the training scripts we use, e.g.
```
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \
--uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \
--seed 0 --deterministic --log logs/erm/civilcomments
```
## Results
### Performance on WILDS-CivilComments (DistilBert)
| Methods | Val Avg Acc | Val Worst-Group Acc | Test Avg Acc | Test Worst-Group Acc | GPU Memory Usage(GB)|
| --- | --- | --- | --- | --- | --- |
| ERM | 89.2 | 67.7 | 88.9 | 68.5 | 6.4 |
### Performance on WILDS-Amazon (DistilBert)
| Methods | Val Avg Acc | Test Avg Acc | Val 10% Acc | Test 10% Acc | GPU Memory Usage(GB)|
| --- | --- | --- | --- | --- | --- |
| ERM | 72.6 | 71.6 | 54.7 | 53.8 | 12.8 |
### Visualization
We use tensorboard to record the training process and visualize the outputs of the models.
```
tensorboard --logdir=logs
```
#### WILDS-CivilComments
#### WILDS-Amazon
================================================
FILE: examples/domain_adaptation/wilds_text/erm.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import argparse
import shutil
import time
import pprint
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import AdamW, get_linear_schedule_with_warmup
import wilds
from wilds.common.grouper import CombinatorialGrouper
import utils
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy
def main(args):
    """Train (or, depending on ``args.phase``, only evaluate) an ERM model on a
    WILDS text-classification dataset, logging metrics to TensorBoard.
    """
    logger = CompleteLogger(args.log, args.phase)
    writer = SummaryWriter(args.log)
    pprint.pprint(args)

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    if args.deterministic:
        # trade speed for reproducibility: disable benchmark autotuning and seed everything
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # Data loading code
    # the same tokenizer-based transform is used for training and validation
    train_transform = utils.get_transform(args.arch, args.max_token_length)
    val_transform = utils.get_transform(args.arch, args.max_token_length)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True)

    # create model
    print("=> using model '{}'".format(args.arch))
    model = utils.get_model(args.arch, args.num_classes)
    # NOTE(review): .to() without arguments is a no-op; .cuda() already moved the model
    model = model.cuda().to()

    # Data loading code
    train_labeled_sampler = None
    if args.uniform_over_groups:
        # re-weight samples so that each (groupby_fields) group is drawn uniformly
        train_grouper = CombinatorialGrouper(dataset=labeled_dataset, groupby_fields=args.groupby_fields)
        groups, group_counts = train_grouper.metadata_to_group(train_labeled_dataset.metadata_array, return_counts=True)
        group_weights = 1 / group_counts
        weights = group_weights[groups]
        train_labeled_sampler = WeightedRandomSampler(weights, len(train_labeled_dataset), replacement=True)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler
    )

    # BERT-style AdamW setup: no weight decay on biases and LayerNorm weights
    no_decay = ['bias', 'LayerNorm.weight']
    decay_params = []
    no_decay_params = []
    for names, params in model.named_parameters():
        if any(nd in names for nd in no_decay):
            no_decay_params.append(params)
        else:
            decay_params.append(params)
    params = [
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
    optimizer = AdamW(params, lr=args.lr)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_training_steps=len(train_labeled_loader) * args.epochs,
                                                   num_warmup_steps=0)
    # NOTE(review): these flags suggest per-batch stepping, but train() below never
    # calls lr_scheduler.step() inside its batch loop — confirm the intended schedule
    lr_scheduler.step_every_batch = True
    lr_scheduler.use_metric = False

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    if args.phase == 'test':
        # resume from the latest checkpoint
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            utils.validate(d, model, -1, writer, args)
        return

    best_val_metric = 0
    test_metric = 0
    for epoch in range(args.epochs):
        lr_scheduler.step(epoch)
        print(lr_scheduler.get_last_lr())
        writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)

        # evaluate on validation set
        # NOTE(review): assumes args.test_list contains both 'val' and 'test' (the
        # parser default); otherwise tmp_val_metric / tmp_test_metric are unbound below
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            if n == 'val':
                tmp_val_metric = utils.validate(d, model, epoch, writer, args)
            elif n == 'test':
                tmp_test_metric = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        is_best = tmp_val_metric > best_val_metric
        best_val_metric = max(tmp_val_metric, best_val_metric)
        torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
        if is_best:
            # report the test metric of the epoch with the best validation metric
            test_metric = tmp_test_metric
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))

    print('best val performance: {:.3f}'.format(best_val_metric))
    print('test performance: {:.3f}'.format(test_metric))
    logger.close()
    writer.close()
def train(train_loader, model, criterion, optimizer, epoch, writer, args):
    """Run one training epoch of ``model`` over ``train_loader``.

    Loss/accuracy meters and TensorBoard scalars are only refreshed every
    ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()

    end = time.time()
    for step, (tokens, labels, metadata) in enumerate(train_loader):
        # forward pass and loss
        logits = model(tokens.cuda())
        loss = criterion(logits, labels.cuda())

        # backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % args.print_freq == 0:
            # Metrics are only computed every print_freq iterations: they incur an
            # allreduce plus host<->device syncs, so doing this per iteration would
            # hurt throughput.
            prec1, = accuracy(logits.data, labels.cuda(), topk=(1,))
            losses.update(loss, tokens.size(0))
            top1.update(prec1, tokens.size(0))
            global_step = epoch * len(train_loader) + step
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            writer.add_scalar("train/loss", loss, global_step)
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch, step, len(train_loader),
                      args.batch_size[0] / batch_time.val,
                      args.batch_size[0] / batch_time.avg,
                      batch_time=batch_time,
                      loss=losses, top1=top1))
if __name__ == '__main__':
    model_names = utils.get_model_names()
    # fix: description was copy-pasted from an ImageNet example; this script runs
    # ERM on WILDS text-classification datasets
    parser = argparse.ArgumentParser(description='ERM for WILDS Text Classification')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='civilcomments', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) +
                             ' (default: civilcomments)')
    parser.add_argument('--unlabeled-list', nargs='+', default=[])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default='acc_wg',
                        help='metric used to evaluate model performance. (default: worst group accuracy)')
    parser.add_argument('--uniform-over-groups', action='store_true',
                        help='sample examples such that batches are uniform over groups')
    parser.add_argument('--groupby-fields', nargs='+',
                        help='Group data by given fields. It means that items which have the same'
                             'values in those fields should be grouped.')
    parser.add_argument('--use-unlabeled', action='store_true',
                        help='Whether use unlabeled data for training or not.')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='distilbert-base-uncased',
                        choices=model_names,
                        help='model architecture: ' +
                             ' | '.join(model_names) +
                             ' (default: distilbert-base-uncased)')
    parser.add_argument('--max-token-length', type=int, default=300,
                        help='The maximum size of a sequence.')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Learning rate.')
    parser.add_argument('--weight-decay', '--wd', default=0.01, type=float,
                        metavar='W', help='weight decay (default: 0.01)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=5, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(16, 16), type=int, nargs='+',
                        metavar='N', help='mini-batch size per process for source'
                                          ' and target domain (default: (16, 16))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--log', type=str, default='src_only',
                        help='Where to save logs, checkpoints and debugging images.')
    # fix: help text previously rendered as "...the model.When phase is 'analysis'm
    # only analysis the model." — missing space and a typo in the second sentence
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_text/erm.sh
================================================
# ERM baselines on the WILDS text-classification benchmarks (see README for results).

# civilcomments
# Group-uniform sampling over the (y, black) groups; model selection by worst-group accuracy.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \
  --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \
  --seed 0 --deterministic --log logs/erm/civilcomments

# amazon
# Longer sequences and larger batches; model selection by 10th-percentile accuracy.
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "amazon" --max-token-length 512 \
  --lr 1e-5 -b 24 24 --epochs 3 --metric "10th_percentile_acc" --seed 0 --deterministic --log logs/erm/amazon
================================================
FILE: examples/domain_adaptation/wilds_text/requirements.txt
================================================
wilds
tensorflow
tensorboard
transformers
================================================
FILE: examples/domain_adaptation/wilds_text/utils.py
================================================
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import time
import sys
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, ConcatDataset
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification
import wilds
sys.path.append('../../..')
from tllib.utils.meter import AverageMeter, ProgressMeter
class DistilBertClassifier(DistilBertForSequenceClassification):
    """
    Adapted from https://github.com/p-lambda/wilds

    Thin wrapper around ``DistilBertForSequenceClassification`` that accepts the
    stacked tensor produced by the tokenizer transform (token ids and attention
    mask stacked along the last dimension) and returns only the logits.
    """

    def __call__(self, x):
        # x[:, :, 0] carries the token ids, x[:, :, 1] the attention mask
        ids = x[:, :, 0]
        mask = x[:, :, 1]
        result = super().__call__(input_ids=ids, attention_mask=mask)
        return result[0]
def get_transform(arch, max_token_length):
    """
    Adapted from https://github.com/p-lambda/wilds

    Return a callable that tokenizes raw text into a ``(max_token_length, 2)``
    tensor of token ids and attention mask for the given architecture.

    Args:
        arch: model architecture name; only 'distilbert-base-uncased' is supported.
        max_token_length: sequences are padded/truncated to this length.

    Raises:
        ValueError: if ``arch`` is not a recognized architecture.
    """
    if arch == 'distilbert-base-uncased':
        tokenizer = DistilBertTokenizerFast.from_pretrained(arch)
    else:
        # Bug fix: "{arch}".format(arch) raised KeyError('arch') at runtime instead
        # of this ValueError; the named field must be passed as a keyword
        raise ValueError("Model: {arch} not recognized".format(arch=arch))

    def transform(text):
        tokens = tokenizer(text, padding='max_length', truncation=True,
                           max_length=max_token_length, return_tensors='pt')
        if arch == 'bert_base_uncased':
            x = torch.stack(
                (
                    tokens["input_ids"],
                    tokens["attention_mask"],
                    tokens["token_type_ids"],
                ),
                dim=2,
            )
        elif arch == 'distilbert-base-uncased':
            x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2)
        x = torch.squeeze(x, dim=0)  # First shape dim is always 1
        return x
    return transform
def get_dataset(dataset_name, root, unlabeled_list=('extra_unlabeled',), test_list=('test',),
                transform_train=None, transform_test=None, use_unlabeled=True, verbose=True):
    """Build labeled/unlabeled train subsets and test subsets of a WILDS text dataset.

    Args:
        dataset_name: name accepted by :func:`wilds.get_dataset`.
        root: root directory where the data is downloaded / stored.
        unlabeled_list: split names taken from the unlabeled variant of the dataset.
        test_list: split names taken from the labeled dataset for evaluation.
        transform_train: transform applied to training subsets.
        transform_test: transform applied to test subsets.
        use_unlabeled: when False, the unlabeled dataset is not downloaded and
            ``train_unlabeled_dataset`` is returned as None.
        verbose: print the size of each subset and the number of classes.

    Returns:
        Tuple ``(train_labeled_dataset, train_unlabeled_dataset, test_datasets,
        labeled_dataset, num_classes, class_names)``.
    """
    # Bug fix: the default arguments are tuples, and the verbose branch below
    # concatenates them with lists (['train'] + unlabeled_list + test_list),
    # which raises TypeError for list + tuple. Normalize to lists up front.
    unlabeled_list = list(unlabeled_list)
    test_list = list(test_list)
    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True)
    train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train)
    if use_unlabeled:
        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)
        train_unlabeled_datasets = [
            unlabeled_dataset.get_subset(u, transform=transform_train)
            for u in unlabeled_list
        ]
        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)
    else:
        unlabeled_list = []
        train_unlabeled_datasets = []
        train_unlabeled_dataset = None
    test_datasets = [
        labeled_dataset.get_subset(t, transform=transform_test)
        for t in test_list
    ]
    num_classes = labeled_dataset.n_classes
    class_names = list(range(num_classes))
    if verbose:
        print('Datasets')
        for n, d in zip(['train'] + unlabeled_list + test_list,
                        [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets):
            print('\t{}:{}'.format(n, len(d)))
        print('\t#classes:', num_classes)
    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, num_classes, class_names
def get_model_names():
    """Return the architecture names supported by this example."""
    supported = ['distilbert-base-uncased']
    return supported
def get_model(arch, num_classes):
    """Instantiate the classifier for ``arch`` with ``num_classes`` output labels.

    Raises:
        ValueError: if ``arch`` is not a supported architecture.
    """
    if arch != 'distilbert-base-uncased':
        raise ValueError('{} is not supported'.format(arch))
    return DistilBertClassifier.from_pretrained(arch, num_labels=num_classes)
def reduce_tensor(tensor, world_size):
    """All-reduce ``tensor`` across ``world_size`` processes and return the mean.

    The tensor is cloned first so the caller's copy is left untouched.
    Requires ``torch.distributed`` to be initialized.
    """
    rt = tensor.clone()
    # dist.reduce_op is a long-deprecated alias (removed in recent PyTorch);
    # dist.ReduceOp is the supported spelling
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt
def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds

    Collate a homogeneous list: tensors are concatenated along dim 0; lists are
    joined one level deep (no recursive collation, so each element may be, e.g.,
    its own dict); dicts sharing the same keys are merged into one dict whose
    values are collated recursively per key.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    head = vec[0]
    if torch.is_tensor(head):
        return torch.cat(vec)
    if isinstance(head, list):
        joined = []
        for part in vec:
            joined += part
        return joined
    if isinstance(head, dict):
        collated = {}
        for key in head:
            collated[key] = collate_list([item[key] for item in vec])
        return collated
    raise TypeError("Elements of the list to collate must be tensors or dicts.")
def validate(val_dataset, model, epoch, writer, args):
    """Evaluate ``model`` on ``val_dataset`` using the WILDS evaluation suite.

    Predicted classes (argmax over logits) from all batches are gathered and
    scored with ``val_dataset.eval``; each non-zero metric whose name does not
    contain "Other" is logged to TensorBoard at ``global_step=epoch``, and the
    metric named by ``args.metric`` is returned.
    """
    val_sampler = None
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    all_y_true = []
    all_y_pred = []
    all_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        all_y_true.append(target)
        all_y_pred.append(output.argmax(1))
        all_metadata.append(metadata)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Bug fix: erm.py's argument parser does not define --local-rank, so reading
        # args.local_rank directly raised AttributeError; default to rank 0
        if getattr(args, 'local_rank', 0) == 0 and i % args.print_freq == 0:
            progress.display(i)

    # evaluate with the dataset's official metric suite
    results = val_dataset.eval(
        collate_list(all_y_pred),
        collate_list(all_y_true),
        collate_list(all_metadata)
    )
    print(results[1])
    for k, v in results[0].items():
        # skip zero-valued metrics and per-group "Other" entries
        if v == 0 or "Other" in k:
            continue
        writer.add_scalar("test/{}".format(k), v, global_step=epoch)
    return results[0][args.metric]
================================================
FILE: examples/domain_generalization/image_classification/README.md
================================================
# Domain Generalization for Image Classification
## Installation
It’s suggested to use **pytorch==1.7.1** and **torchvision==0.8.2** in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
- [DomainNet](http://ai.bu.edu/M3SDA/)
- [PACS](https://domaingeneralization.github.io/#data)
## Supported Methods
- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008)
- [Learning to Generalize: Meta-Learning for Domain Generalization (MLDG, 2018 AAAI)](https://arxiv.org/pdf/1710.03463.pdf)
- [Invariant Risk Minimization (IRM)](https://arxiv.org/abs/1907.02893)
- [Out-of-Distribution Generalization via Risk Extrapolation (VREx, 2021 ICML)](https://arxiv.org/abs/2003.00688)
- [Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (GroupDRO)](https://arxiv.org/abs/1911.08731)
- [Deep CORAL: Correlation Alignment for Deep Domain Adaptation (Deep Coral, 2016 ECCV)](https://arxiv.org/abs/1607.01719)
## Usage
The shell files give the script to reproduce the benchmark with specified hyper-parameters.
For example, if you want to train IRM on Office-Home, use the following script
```shell script
# Train with IRM on Office-Home Ar Cl Rw -> Pr task using ResNet 50.
# Assume you have put the datasets under the path `data/office-home`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr
```
Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain,
and ``--log`` specifies where to store results.
## Experiment and Results
Following [DomainBed](https://github.com/facebookresearch/DomainBed), we select hyper-parameters based on
the model's performance on `training-domain validation set` (first rule in DomainBed).
Concretely, we save model with the highest accuracy on `training-domain validation set` and then
load this checkpoint to test on the target domain.
Here are some differences between our implementation and DomainBed. For the model,
we do not freeze `BatchNorm2d` layers and do not insert additional `Dropout` layer except for `PACS` dataset.
For the optimizer, we use `SGD` with momentum by default and find this usually achieves better performance than `Adam`.
**Notations**
- ``ERM`` refers to the model trained with data from the source domain.
- ``Avg`` is the accuracy reported by `TLlib`.
### PACS accuracy on ResNet-50
| Methods | avg | A | C | P | S |
|----------|------|------|------|------|------|
| ERM | 86.4 | 88.5 | 78.4 | 97.2 | 81.4 |
| IBN | 87.8 | 88.2 | 84.5 | 97.1 | 81.4 |
| MixStyle | 87.4 | 87.8 | 82.3 | 95.0 | 84.5 |
| MLDG | 87.2 | 88.2 | 81.4 | 96.6 | 82.5 |
| IRM | 86.9 | 88.0 | 82.5 | 98.0 | 79.0 |
| VREx | 87.0 | 87.2 | 82.3 | 97.4 | 81.0 |
| GroupDRO | 87.3 | 88.9 | 81.7 | 97.8 | 80.8 |
| CORAL | 86.4 | 89.1 | 80.0 | 97.4 | 79.1 |
### Office-Home accuracy on ResNet-50
| Methods | avg | A | C | P | R |
|----------|------|------|------|------|------|
| ERM | 70.8 | 68.3 | 55.9 | 78.9 | 80.0 |
| IBN | 69.9 | 67.4 | 55.2 | 77.3 | 79.6 |
| MixStyle | 71.7 | 66.8 | 58.1 | 78.0 | 79.9 |
| MLDG | 70.3 | 65.9 | 57.6 | 78.2 | 79.6 |
| IRM | 70.3 | 66.7 | 54.8 | 78.6 | 80.9 |
| VREx | 70.2 | 66.9 | 54.9 | 78.2 | 80.9 |
| GroupDRO | 70.0 | 66.7 | 55.2 | 78.8 | 79.9 |
| CORAL | 70.9 | 68.3 | 55.4 | 78.8 | 81.0 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{IBN-Net,
author = {Xingang Pan and Ping Luo and Jianping Shi and Xiaoou Tang},
title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},
booktitle = {ECCV},
year = {2018}
}
@inproceedings{mixstyle,
title={Domain Generalization with MixStyle},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
booktitle={ICLR},
year={2021}
}
@inproceedings{MLDG,
title={Learning to Generalize: Meta-Learning for Domain Generalization},
author={Li, Da and Yang, Yongxin and Song, Yi-Zhe and Hospedales, Timothy},
booktitle={AAAI},
year={2018}
}
@misc{IRM,
title={Invariant Risk Minimization},
author={Martin Arjovsky and Léon Bottou and Ishaan Gulrajani and David Lopez-Paz},
year={2020},
eprint={1907.02893},
archivePrefix={arXiv},
primaryClass={stat.ML}
}
@inproceedings{VREx,
title={Out-of-Distribution Generalization via Risk Extrapolation (REx)},
author={David Krueger and Ethan Caballero and Joern-Henrik Jacobsen and Amy Zhang and Jonathan Binas and Dinghuai Zhang and Remi Le Priol and Aaron Courville},
year={2021},
booktitle={ICML},
}
@inproceedings{GroupDRO,
title={Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization},
author={Shiori Sagawa and Pang Wei Koh and Tatsunori B. Hashimoto and Percy Liang},
year={2020},
booktitle={ICLR}
}
@inproceedings{deep_coral,
title={Deep coral: Correlation alignment for deep domain adaptation},
author={Sun, Baochen and Saenko, Kate},
booktitle={ECCV},
year={2016},
}
```
================================================
FILE: examples/domain_generalization/image_classification/coral.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.alignment.coral import CorrelationAlignmentLoss
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run CORAL domain-generalization training, testing, or analysis,
    depending on ``args.phase``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # train/val are drawn from the source domains; test from the held-out target domains
    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    # each mini-batch mixes samples from n_domains_per_batch source domains
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # define loss function
    correlation_alignment_loss = CorrelationAlignmentLoss().to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        # NOTE(review): get_lr() is deprecated in newer PyTorch in favor of get_last_lr()
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, correlation_alignment_loss, args.n_domains_per_batch,
              epoch, args)

        # evaluate on validation set (model selection criterion)
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (tracked as an oracle number only; selection uses val)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the best-on-validation checkpoint on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          correlation_alignment_loss: CorrelationAlignmentLoss, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """Train `model` for one epoch with CORAL (correlation alignment) regularization.

    Each mini-batch stacks samples from ``n_domains_per_batch`` domains along dim 0.
    The per-domain cross-entropy losses are averaged, and the correlation alignment
    penalty is averaged over all domain pairs. Assumes ``n_domains_per_batch >= 2``
    (otherwise the pair normalization below divides by zero).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time BEFORE the forward pass; the original updated this
        # meter after `model(x_all)`, so "data time" wrongly included compute time
        # (erm.py and groupdro.py in this directory measure it here)
        data_time.update(time.time() - end)

        # compute output
        y_all, f_all = model(x_all)

        # separate predictions, features and labels into per-domain chunks
        y_all = y_all.chunk(n_domains_per_batch, dim=0)
        f_all = f_all.chunk(n_domains_per_batch, dim=0)
        labels_all = labels_all.chunk(n_domains_per_batch, dim=0)

        loss_ce = 0
        loss_penalty = 0
        cls_acc = 0
        for domain_i in range(n_domains_per_batch):
            # cls loss
            y_i, labels_i = y_all[domain_i], labels_all[domain_i]
            loss_ce += F.cross_entropy(y_i, labels_i)
            # update acc
            cls_acc += accuracy(y_i, labels_i)[0] / n_domains_per_batch
            # correlation alignment loss over all unordered domain pairs (i < j)
            for domain_j in range(domain_i + 1, n_domains_per_batch):
                f_i = f_all[domain_i]
                f_j = f_all[domain_j]
                loss_penalty += correlation_alignment_loss(f_i, f_j)

        # normalize: cls loss by the number of domains, penalty by the number of
        # domain pairs C(n, 2)
        loss_ce /= n_domains_per_batch
        loss_penalty /= n_domains_per_batch * (n_domains_per_batch - 1) / 2

        loss = loss_ce + loss_penalty * args.trade_off

        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(loss_penalty.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface for the CORAL domain-generalization example.
    parser = argparse.ArgumentParser(description='CORAL for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    # weight of the correlation alignment penalty relative to the classification loss
    parser.add_argument('--trade-off', default=1, type=float,
                        help='the trade off hyper parameter for correlation alignment loss')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    # each mini-batch is split evenly among this many source domains
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='coral',
                        help="Where to save logs, checkpoints and debugging images.")
    # phase selects between training, checkpoint evaluation, and feature analysis
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/coral.sh
================================================
#!/usr/bin/env bash
# CORAL domain-generalization benchmarks: each command trains on the source
# domains listed after -s and evaluates on the held-out target domain after -t.
# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_P
CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_A
CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_C
CUDA_VISIBLE_DEVICES=0 python coral.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/coral/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/coral/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/coral/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/coral/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python coral.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/coral/OfficeHome_Ar
# ResNet50, DomainNet (larger dataset: more iterations per epoch, larger lr)
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python coral.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/coral/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/erm.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for the ERM (empirical risk minimization) baseline.

    Depending on ``args.phase`` this either trains the classifier on the source
    domains, only evaluates the best checkpoint on the target domains ('test'),
    or runs feature analysis with t-SNE and A-distance ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: train/val splits come from the source domains,
    # the test split comes from the (unseen) target domains
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (oracle: best test accuracy over all epochs)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the checkpoint with the best validation accuracy on the test set.
    # Load via CPU first (map_location='cpu') for consistency with the resume path
    # above; the original omitted it here.
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'), map_location='cpu'))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))

    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of plain ERM training: cross-entropy on mixed-domain batches."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, targets, _ = next(train_iter)
        images, targets = images.to(device), targets.to(device)

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tic)

        # forward pass and loss
        logits, _ = model(images)
        loss = F.cross_entropy(logits, targets)
        top1 = accuracy(logits, targets)[0]

        losses.update(loss.item(), images.size(0))
        cls_accs.update(top1.item(), images.size(0))

        # backward pass and parameter/lr update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # total time for this iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for the ERM baseline example.
    parser = argparse.ArgumentParser(description='Baseline for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    # NOTE: unlike coral.py/groupdro.py this script seeds by default (seed=0)
    parser.add_argument('--seed', default=0, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    # phase selects between training, checkpoint evaluation, and feature analysis
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/erm.sh
================================================
#!/usr/bin/env bash
# ERM baseline benchmarks: each command trains on the source domains after -s
# and evaluates on the held-out target domain after -t.
# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_P
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_A
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_C
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/erm/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/erm/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/erm/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/erm/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/erm/OfficeHome_Ar
# ResNet50, DomainNet (larger dataset: more iterations per epoch, larger lr)
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/groupdro.py
================================================
"""
Adapted from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for GroupDRO domain generalization.

    Depending on ``args.phase`` this either trains the classifier on the source
    domains with group distributionally-robust optimization, only evaluates the
    best checkpoint on the target domains ('test'), or runs feature analysis
    with t-SNE and A-distance ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: the sampler groups each training batch into
    # args.n_domains_per_batch contiguous same-domain chunks
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)
    # one sub-dataset per source domain (train_dataset is a concatenation of them)
    num_all_domains = len(train_dataset.datasets)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)
    # maintains per-domain weights that are boosted for high-loss domains (GroupDRO)
    domain_weight_module = AutomaticUpdateDomainWeightModule(num_all_domains, args.eta, device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, domain_weight_module, args.n_domains_per_batch, epoch,
              args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (oracle: best test accuracy over all epochs)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the checkpoint with the best validation accuracy on the test set.
    # Load via CPU first (map_location='cpu') for consistency with the resume path
    # above; the original omitted it here.
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'), map_location='cpu'))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))

    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          domain_weight_module: AutomaticUpdateDomainWeightModule, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """One epoch of GroupDRO: per-domain losses are re-weighted so that the
    currently worst-performing domains contribute more to the total loss."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x_all, labels_all, domain_labels = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)
        domain_labels = domain_labels.to(device)

        # the sampler yields same-domain chunks, so the first label of each chunk
        # identifies that chunk's domain index
        domain_label_chunks = domain_labels.chunk(n_domains_per_batch, dim=0)
        sampled_domain_idxes = [chunk[0].item() for chunk in domain_label_chunks]

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tic)

        loss_per_domain = torch.zeros(n_domains_per_batch).to(device)
        cls_acc = 0
        x_chunks = x_all.chunk(n_domains_per_batch, dim=0)
        label_chunks = labels_all.chunk(n_domains_per_batch, dim=0)
        for domain_id in range(n_domains_per_batch):
            logits, _ = model(x_chunks[domain_id])
            loss_per_domain[domain_id] = F.cross_entropy(logits, label_chunks[domain_id])
            cls_acc += accuracy(logits, label_chunks[domain_id])[0] / n_domains_per_batch

        # update domain weights from the current losses, then form the weighted loss
        domain_weight_module.update(loss_per_domain, sampled_domain_idxes)
        domain_weight = domain_weight_module.get_domain_weight(sampled_domain_idxes)
        loss = (loss_per_domain * domain_weight).sum()

        losses.update(loss.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # total time for this iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for the GroupDRO example.
    parser = argparse.ArgumentParser(description='GroupDRO for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    # step size for the exponential update of the GroupDRO domain weights
    parser.add_argument('--eta', default=1e-2, type=float,
                        help='the eta hyper parameter')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    # each mini-batch is split evenly among this many source domains
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='groupdro',
                        help="Where to save logs, checkpoints and debugging images.")
    # phase selects between training, checkpoint evaluation, and feature analysis
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/groupdro.sh
================================================
#!/usr/bin/env bash
# GroupDRO benchmarks: each command trains on the source domains after -s and
# evaluates on the held-out target domain after -t.
# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_P
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_A
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_C
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/groupdro/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/groupdro/OfficeHome_Ar
# ResNet50, DomainNet (larger dataset: more iterations per epoch, larger lr)
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/ibn.sh
================================================
#!/usr/bin/env bash
# IBN-Net benchmarks: runs the plain ERM script with an Instance-Batch
# Normalization backbone (resnet50_ibn_b) instead of vanilla ResNet50.
# IBN_ResNet50_b, PACS
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_P
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_A
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_C
CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_S
# IBN_ResNet50_b, Office-Home
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Ar
# IBN_ResNet50_b, DomainNet (larger dataset: more iterations per epoch, larger lr)
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/irm.py
================================================
"""
Adapted from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.autograd as autograd
import utils
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class InvariancePenaltyLoss(nn.Module):
r"""Invariance Penalty Loss from `Invariant Risk Minimization `_.
We adopt implementation from `DomainBed `_. Given classifier
output :math:`y` and ground truth :math:`labels`, we split :math:`y` into two parts :math:`y_1, y_2`, corresponding
labels are :math:`labels_1, labels_2`. Next we calculate cross entropy loss with respect to a dummy classifier
:math:`w`, resulting in :math:`grad_1, grad_2` . Invariance penalty is then :math:`grad_1*grad_2`.
Inputs:
- y: predictions from model
- labels: ground truth
Shape:
- y: :math:`(N, C)` where C means the number of classes.
- labels: :math:`(N, )` where N mean mini-batch size
"""
def __init__(self):
super(InvariancePenaltyLoss, self).__init__()
self.scale = torch.tensor(1.).requires_grad_()
def forward(self, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
loss_1 = F.cross_entropy(y[::2] * self.scale, labels[::2])
loss_2 = F.cross_entropy(y[1::2] * self.scale, labels[1::2])
grad_1 = autograd.grad(loss_1, [self.scale], create_graph=True)[0]
grad_2 = autograd.grad(loss_2, [self.scale], create_graph=True)[0]
penalty = torch.sum(grad_1 * grad_2)
return penalty
def main(args: argparse.Namespace):
    """Train/evaluate an IRM classifier for domain generalization.

    Depending on ``args.phase``: 'train' runs the full training loop and model
    selection on the source-domain validation split; 'test' only evaluates the
    best checkpoint on the target test split; 'analysis' extracts features and
    produces a t-SNE plot plus an A-distance estimate.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding turns on deterministic cuDNN kernels, trading speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # Train/val come from the source domains; test comes from the held-out target domains.
    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    # RandomDomainSampler groups each mini-batch by domain so the per-domain
    # chunks in train() line up with actual domains.
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # define loss function
    invariance_penalty_loss = InvariancePenaltyLoss().to(device)

    # for simplicity: the anneal boundary must fall exactly on an epoch boundary,
    # because the optimizer reset below is only checked once per epoch.
    assert args.anneal_iters % args.iters_per_epoch == 0

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        if epoch * args.iters_per_epoch == args.anneal_iters:
            # reset optimizer to avoid sharp jump in gradient magnitudes
            # (the IRM trade-off weight jumps from 1 to args.trade_off at this point)
            optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,
                            weight_decay=args.wd, nesterov=True)
            lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch - args.anneal_iters)
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, invariance_penalty_loss, args.n_domains_per_batch, epoch,
              args)
        # evaluate on validation set (model selection uses source-domain val accuracy)
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)
        # evaluate on test set (tracked only to report the oracle accuracy below)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the selected (best-on-val) checkpoint on the target test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          invariance_penalty_loss: InvariancePenaltyLoss, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of IRM training.

    Each mini-batch is assumed to contain ``n_domains_per_batch`` equal-size,
    contiguous per-domain chunks (guaranteed by RandomDomainSampler in main()).
    The loss is cross entropy plus the invariance penalty averaged over domains,
    weighted by ``args.trade_off`` after ``args.anneal_iters`` global iterations
    (weight 1 before that).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_all, _ = model(x_all)

        # cls loss over the whole batch (all domains pooled)
        loss_ce = F.cross_entropy(y_all, labels_all)

        # penalty loss, one term per domain chunk
        loss_penalty = 0
        for y_per_domain, labels_per_domain in zip(y_all.chunk(n_domains_per_batch, dim=0),
                                                   labels_all.chunk(n_domains_per_batch, dim=0)):
            # normalize loss by domain num
            loss_penalty += invariance_penalty_loss(y_per_domain, labels_per_domain) / n_domains_per_batch

        # anneal schedule: penalty weight is 1 until anneal_iters, then args.trade_off
        global_iter = epoch * args.iters_per_epoch + i
        if global_iter >= args.anneal_iters:
            trade_off = args.trade_off
        else:
            trade_off = 1
        loss = loss_ce + loss_penalty * trade_off

        cls_acc = accuracy(y_all, labels_all)[0]
        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(loss_penalty.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface for IRM training; parsed args are handed to main().
    parser = argparse.ArgumentParser(description='IRM for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('--trade-off', default=1, type=float,
                        help='the trade off hyper parameter for irm penalty')
    # NOTE: main() asserts anneal-iters is a multiple of iters-per-epoch.
    parser.add_argument('--anneal-iters', default=500, type=int,
                        help='anneal iterations (trade off is set to 1 during these iterations)')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='irm',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/irm.sh
================================================
#!/usr/bin/env bash
# Leave-one-domain-out IRM experiments: each line trains on all domains except
# the -t target and evaluates generalization to that held-out target.

# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_P
CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_A
CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_C
CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/irm/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/irm/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/irm/OfficeHome_Ar
# ResNet50, DomainNet
# DomainNet is much larger, so these runs use more iterations per epoch,
# a matching anneal schedule, and a higher learning rate.
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/mixstyle.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
import tllib.normalization.mixstyle.resnet as models
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train/evaluate a MixStyle classifier for domain generalization.

    Depending on ``args.phase``: 'train' runs the full training loop with model
    selection on the source-domain validation split; 'test' only evaluates the
    best checkpoint on the target test split; 'analysis' extracts features and
    produces a t-SNE plot plus an A-distance estimate.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding turns on deterministic cuDNN kernels, trading speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # Train/val come from the source domains; test comes from the held-out target domains.
    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    # MixStyle mixes feature statistics across instances; batches are drawn from
    # 2 domains at a time here.
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=2)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model: a ResNet variant with MixStyle modules inserted at the given layers
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha,
                                          pretrained=True)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, epoch, args)
        # evaluate on validation set (model selection uses source-domain val accuracy)
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)
        # evaluate on test set (tracked only to report the oracle accuracy below)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the selected (best-on-val) checkpoint on the target test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer,
          lr_scheduler: CosineAnnealingLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of plain cross-entropy training (MixStyle lives inside the backbone).

    Steps the optimizer and lr scheduler once per iteration, tracks timing /
    loss / accuracy meters, and prints progress every ``args.print_freq`` steps.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # training mode enables the stochastic MixStyle mixing and dropout/bn updates
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        images, target, _ = next(train_iter)
        images, target = images.to(device), target.to(device)

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tick)

        # forward pass; the classifier also returns features, which are unused here
        logits, _ = model(images)
        loss = F.cross_entropy(logits, target)

        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(accuracy(logits, target)[0].item(), batch_size)

        # backward pass and parameter/lr updates
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # total time for this iteration
        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Discover the MixStyle backbone constructors exported by the models module
    # (lowercase callables, e.g. resnet18/resnet50 variants).
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='MixStyle for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    # MixStyle hyper-parameters, forwarded to the backbone constructor in main()
    parser.add_argument('--mix-layers', nargs='+', help='layers to apply MixStyle')
    parser.add_argument('--mix-p', default=0.5, type=float, help='probability to apply MixStyle')
    parser.add_argument('--mix-alpha', default=0.1, type=float, help='parameter alpha for beta distribution')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='mixstyle',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/mixstyle.sh
================================================
#!/usr/bin/env bash
# Leave-one-domain-out MixStyle experiments: each line trains on all domains
# except the -t target and evaluates generalization to that held-out target.
# --mix-layers selects which ResNet stages get MixStyle modules.

# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s A C S -t P -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_P
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P C S -t A -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_A
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A S -t C -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_C
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A C -t S -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Ar
# ResNet50, DomainNet
# DomainNet is much larger, so these runs use more iterations per epoch and a
# higher learning rate. Note: no --mix-layers given, so MixStyle placement
# falls back to the model constructor's default here.
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/mldg.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import higher
import utils
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train/evaluate an MLDG (meta-learning for DG) classifier.

    Depending on ``args.phase``: 'train' runs the meta-learning training loop
    with model selection on the source-domain validation split; 'test' only
    evaluates the best checkpoint on the target test split; 'analysis' extracts
    features and produces a t-SNE plot plus an A-distance estimate.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding turns on deterministic cuDNN kernels, trading speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # Train/val come from the source domains; test comes from the held-out target domains.
    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    # Each batch must hold enough domain chunks for the meta-train (support) and
    # meta-test (query) split performed in train().
    n_domains_per_batch = args.n_support_domains + args.n_query_domains
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, epoch, n_domains_per_batch, args)
        # evaluate on validation set (model selection uses source-domain val accuracy)
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)
        # evaluate on test set (tracked only to report the oracle accuracy below)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the selected (best-on-val) checkpoint on the target test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))
    logger.close()
def random_split(x_list, labels_list, n_domains_per_batch, n_support_domains):
    """Randomly partition the per-domain chunks of a mini-batch into support and query sets.

    Args:
        x_list: sequence of per-domain input chunks.
        labels_list: sequence of per-domain label chunks (parallel to ``x_list``).
        n_domains_per_batch (int): total number of domains in the mini-batch.
        n_support_domains (int): how many domains form the support set (must be
            strictly smaller than ``n_domains_per_batch``).

    Returns:
        A pair ``(support_domain_list, query_domain_list)``, each a list of
        ``(x, labels)`` tuples, ordered by domain index.
    """
    assert n_support_domains < n_domains_per_batch
    support_domain_idxes = random.sample(range(n_domains_per_batch), n_support_domains)
    # single pass over the domains, routing each (x, labels) pair to its set
    chosen = set(support_domain_idxes)
    support_domain_list = []
    query_domain_list = []
    for idx in range(n_domains_per_batch):
        pair = (x_list[idx], labels_list[idx])
        if idx in chosen:
            support_domain_list.append(pair)
        else:
            query_domain_list.append(pair)
    return support_domain_list, query_domain_list
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,
          n_domains_per_batch: int, args: argparse.Namespace):
    """Train `model` for one epoch with MLDG (Meta-Learning for Domain Generalization).

    Each mini-batch consists of `n_domains_per_batch` equally-sized per-domain chunks,
    which are randomly split into support and query domains. A differentiable copy of
    the model (via ``higher``) is adapted on the support domains in an inner loop; the
    outer loss combines the support-domain loss of the original model with the
    query-domain loss of the adapted copy, so the meta-gradient flows through the
    inner update.

    Args:
        train_iter (ForeverDataIterator): infinite iterator over the training loader,
            yielding ``(x, labels, domain_ids)``.
        model: image classifier; returns ``(logits, features)`` in training mode.
        optimizer: optimizer for the outer (meta) update.
        lr_scheduler (CosineAnnealingLR): stepped once per iteration.
        epoch (int): current epoch index, used only for logging.
        n_domains_per_batch (int): number of per-domain chunks in every mini-batch.
        args (argparse.Namespace): hyper-parameters (``iters_per_epoch``,
            ``inner_iters``, ``n_support_domains``, ``n_query_domains``,
            ``trade_off``, ``print_freq``, ``batch_size``).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, labels, _ = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # split the batch into per-domain chunks, then into support and query domains
        x_list = x.chunk(n_domains_per_batch, dim=0)
        labels_list = labels.chunk(n_domains_per_batch, dim=0)
        support_domain_list, query_domain_list = random_split(x_list, labels_list, n_domains_per_batch,
                                                              args.n_support_domains)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # `inner_model` is a differentiable copy of `model`; `copy_initial_weights=False`
        # keeps the gradient path back to the original parameters alive (see `higher` docs)
        with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (inner_model, inner_optimizer):
            # perform inner optimization on the support domains
            for _ in range(args.inner_iters):
                loss_inner = 0
                for (x_s, labels_s) in support_domain_list:
                    y_s, _ = inner_model(x_s)
                    # normalize loss by support domain num
                    loss_inner += F.cross_entropy(y_s, labels_s) / args.n_support_domains
                inner_optimizer.step(loss_inner)

            # calculate outer loss
            loss_outer = 0
            cls_acc = 0

            # loss on support domains, evaluated with the ORIGINAL parameters
            for (x_s, labels_s) in support_domain_list:
                y_s, _ = model(x_s)
                # normalize loss by support domain num
                loss_outer += F.cross_entropy(y_s, labels_s) / args.n_support_domains

            # loss on query domains, evaluated with the ADAPTED (inner) parameters
            for (x_q, labels_q) in query_domain_list:
                y_q, _ = inner_model(x_q)
                # normalize loss by query domain num
                loss_outer += F.cross_entropy(y_q, labels_q) * args.trade_off / args.n_query_domains
                cls_acc += accuracy(y_q, labels_q)[0] / args.n_query_domains

            # update statistics
            losses.update(loss_outer.item(), args.batch_size)
            cls_accs.update(cls_acc.item(), args.batch_size)

            # compute gradient and do SGD step
            loss_outer.backward()

        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Meta Learning for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('--n-support-domains', type=int, default=1,
                        help='Number of support domains sampled in each iteration')
    parser.add_argument('--n-query-domains', type=int, default=2,
                        help='Number of query domains in each iteration')
    parser.add_argument('--trade-off', type=float, default=1,
                        help='hyper parameter beta')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--inner-iters', default=1, type=int,
                        help='Number of iterations in inner loop')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='mldg',
                        help="Where to save logs, checkpoints and debugging images.")
    # fix: trailing space added so the two concatenated help sentences do not
    # render as "...the model.When phase..."
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/mldg.sh
================================================
#!/usr/bin/env bash
# Training scripts for MLDG (Meta-Learning for Domain Generalization).
# Each command trains on the listed source domains (-s) and holds out one
# target domain (-t); logs and checkpoints are written under --log.

# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_P
CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_A
CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_C
CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Ar
# ResNet50, DomainNet (larger dataset: more iterations per epoch, higher lr)
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_s
================================================
FILE: examples/domain_generalization/image_classification/requirements.txt
================================================
# Extra third-party dependencies for the domain-generalization
# image-classification examples (in addition to tllib's own requirements).
# pytorch-image-models: backbone zoo used by utils.get_model / get_model_names
timm
# WILDS benchmark datasets — presumably used by sibling examples; not referenced in this directory's utils.py (verify)
wilds
# differentiable inner-loop optimization (higher.innerloop_ctx) used by mldg.py
higher
================================================
FILE: examples/domain_generalization/image_classification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import copy
import random
import sys
import time
import timm
import tqdm
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Sampler, Subset, ConcatDataset
sys.path.append('../../..')
from tllib.modules import Classifier as ClassifierBase
import tllib.vision.datasets as datasets
import tllib.vision.models as models
import tllib.normalization.ibn as ibn_models
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
def get_model_names():
    """Return the names of all supported backbones: tllib models, IBN models and timm models."""
    def _factory_names(module):
        # model factories are the lowercase, non-dunder callables of the module
        return sorted(name for name in module.__dict__ if
                      name.islower() and not name.startswith("__") and callable(module.__dict__[name]))

    return _factory_names(models) + _factory_names(ibn_models) + timm.list_models()
def get_model(model_name):
    """Create a pre-trained backbone by name.

    Looks up `model_name` first in tllib.vision.models, then in
    tllib.normalization.ibn, and finally falls back to pytorch-image-models
    (timm). For timm models the classification head is removed so the module
    outputs features, and `out_features` is set to the feature dimension.

    Args:
        model_name (str): name of the backbone architecture.

    Returns:
        torch.nn.Module: the pre-trained feature extractor.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    elif model_name in ibn_models.__dict__:
        # load models (with ibn) from tllib.normalization.ibn
        backbone = ibn_models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except Exception:
            # fix: was a bare `except:`, which also swallowed KeyboardInterrupt
            # and SystemExit; some timm models expose the classifier as `.head`
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone
def get_dataset_names():
    """Return the names of all datasets exposed by tllib.vision.datasets."""
    namespace = datasets.__dict__
    candidates = (name for name in namespace
                  if not name.startswith("__") and callable(namespace[name]))
    return sorted(candidates)
class ConcatDatasetWithDomainLabel(ConcatDataset):
    """A ``ConcatDataset`` whose items additionally carry the id of the domain
    (i.e. the sub-dataset) they come from.

    ``__getitem__`` returns ``(img, target, domain_id)``, where ``domain_id``
    is the position of the sample's sub-dataset in the concatenation order.
    """

    def __init__(self, *args, **kwargs):
        super(ConcatDatasetWithDomainLabel, self).__init__(*args, **kwargs)
        # map every global sample index to the id of its source domain
        self.index_to_domain_id = {}
        start = 0
        for domain_id, end in enumerate(self.cumulative_sizes):
            for idx in range(start, end):
                self.index_to_domain_id[idx] = domain_id
            start = end

    def __getitem__(self, index):
        img, target = super(ConcatDatasetWithDomainLabel, self).__getitem__(index)
        return img, target, self.index_to_domain_id[index]
def get_dataset(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0):
    """Build the train/val/test datasets for domain generalization.

    Following DomainBed, every domain in `task_list` is randomly split into a
    training part (80%) and a validation part (20%); the test split contains
    all samples of the listed domains (the target domains are never split).

    Args:
        dataset_name (str): one of 'PACS', 'OfficeHome', 'DomainNet'.
        root (str): dataset root directory.
        task_list (list): names of the domains to load.
        split (str): which split to return, 'train', 'val' or 'test'.
        download (bool): whether to download the data if absent.
        transform: image transform applied to every sample.
        seed (int): random seed of the 80/20 split.

    Returns:
        ``(dataset, num_classes)`` where the dataset yields
        ``(img, target, domain_id)`` tuples.
    """
    assert split in ['train', 'val', 'test']
    # currently only PACS, OfficeHome and DomainNet are supported
    assert dataset_name in ['PACS', 'OfficeHome', 'DomainNet']
    dataset = datasets.__dict__[dataset_name]

    # we follow DomainBed and split each domain randomly into 80% / 20%:
    # the larger part is used for training, the smaller one for validation
    split_ratio = 0.8
    train_split_list, val_split_list, test_split_list = [], [], []
    num_classes = 0

    for task in task_list:
        if dataset_name == 'PACS':
            all_split = dataset(root=root, task=task, split='all', download=download, transform=transform)
            num_classes = all_split.num_classes
        elif dataset_name == 'OfficeHome':
            all_split = dataset(root=root, task=task, download=download, transform=transform)
            num_classes = all_split.num_classes
        else:
            # DomainNet ships with a predefined train/test split; merge both parts
            train_part = dataset(root=root, task=task, split='train', download=download, transform=transform)
            test_part = dataset(root=root, task=task, split='test', download=download, transform=transform)
            num_classes = train_part.num_classes
            all_split = ConcatDataset([train_part, test_part])

        train_split, val_split = split_dataset(all_split, int(len(all_split) * split_ratio), seed)
        train_split_list.append(train_split)
        val_split_list.append(val_split)
        # under domain generalization, all samples of a target domain form the test set
        test_split_list.append(all_split)

    dataset_dict = {
        'train': ConcatDatasetWithDomainLabel(train_split_list),
        'val': ConcatDatasetWithDomainLabel(val_split_list),
        'test': ConcatDatasetWithDomainLabel(test_split_list),
    }
    return dataset_dict[split], num_classes
def split_dataset(dataset, n, seed=0):
    """Randomly split `dataset` into two disjoint subsets.

    The first subset holds `n` samples, the second one holds the rest; the
    partition is deterministic for a fixed `seed`.
    """
    assert (n <= len(dataset))
    shuffled = list(range(len(dataset)))
    # a dedicated RandomState keeps the split independent of the global RNG
    np.random.RandomState(seed).shuffle(shuffled)
    return Subset(dataset, shuffled[:n]), Subset(dataset, shuffled[n:])
def validate(val_loader, model, args, device) -> float:
    """Evaluate `model` on `val_loader` and return the average top-1 accuracy (%).

    The loader is expected to yield ``(images, target, domain_id)`` tuples; the
    domain id is ignored here.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target, _) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # forward pass and loss
            output = model(images)
            loss = F.cross_entropy(output, target)

            # record accuracy and loss, weighted by batch size
            acc1 = accuracy(output, target)[0]
            n_samples = images.size(0)
            losses.update(loss.item(), n_samples)
            top1.update(acc1.item(), n_samples)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))
    return top1.avg
def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=True,
                        random_gray_scale=True):
    """Build the training-time augmentation pipeline.

    resizing mode:
        - default: random resized crop with scale factor (0.7, 1.0) and size 224;
        - cen.crop: take the center crop of 224;
        - res.|cen.crop: resize the image to 256 and take the center crop of size 224;
        - res: resize the image to 224;
        - res2x: resize the image to 448;
        - res.|crop: resize the image to 256 and take a random crop of size 224;
        - res.sma|crop: resize the image keeping its aspect ratio such that the
            smaller side is 256, then take a random crop of size 224;
        - inc.crop: "inception crop" from (Szegedy et al., 2015);
        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224.
    """
    # one factory per resizing mode; the transform is only built for the chosen mode
    resizing_factories = {
        'default': lambda: T.RandomResizedCrop(224, scale=(0.7, 1.0)),
        'cen.crop': lambda: T.CenterCrop(224),
        'res.|cen.crop': lambda: T.Compose([ResizeImage(256), T.CenterCrop(224)]),
        'res': lambda: ResizeImage(224),
        'res2x': lambda: ResizeImage(448),
        'res.|crop': lambda: T.Compose([T.Resize((256, 256)), T.RandomCrop(224)]),
        'res.sma|crop': lambda: T.Compose([T.Resize(256), T.RandomCrop(224)]),
        'inc.crop': lambda: T.RandomResizedCrop(224),
        'cif.crop': lambda: T.Compose([T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224)]),
    }
    if resizing not in resizing_factories:
        raise NotImplementedError(resizing)
    transforms = [resizing_factories[resizing]()]

    # optional stochastic augmentations
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))
    if random_gray_scale:
        transforms.append(T.RandomGrayscale())

    # convert to tensor and apply ImageNet normalization statistics
    transforms.extend([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return T.Compose(transforms)
def get_val_transform(resizing='default'):
    """Build the evaluation-time preprocessing pipeline.

    resizing mode:
        - default: resize the image to 224;
        - res2x: resize the image to 448;
        - res.|cen.crop: resize the image to 256 and take the center crop of size 224;
    """
    resizing_factories = {
        'default': lambda: ResizeImage(224),
        'res2x': lambda: ResizeImage(448),
        'res.|cen.crop': lambda: T.Compose([ResizeImage(256), T.CenterCrop(224)]),
    }
    if resizing not in resizing_factories:
        raise NotImplementedError(resizing)
    # resize, convert to tensor, then apply ImageNet normalization statistics
    return T.Compose([
        resizing_factories[resizing](),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
def collect_feature(data_loader, feature_extractor: nn.Module, device: torch.device,
                    max_num_features=None) -> torch.Tensor:
    r"""Fetch data from `data_loader` and use `feature_extractor` to collect features.

    This function is specific to domain generalization: each element of
    `data_loader` is a ``(images, labels, domain_labels)`` tuple.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader.
        feature_extractor (torch.nn.Module): A feature extractor.
        device (torch.device)
        max_num_features (int): maximum number of mini-batches to collect;
            ``None`` means iterate over the whole loader.

    Returns:
        Features in shape (num collected samples, :math:`|\mathcal{F}|`).
    """
    feature_extractor.eval()
    collected = []
    with torch.no_grad():
        for batch_idx, (images, target, domain_labels) in enumerate(tqdm.tqdm(data_loader)):
            # stop once the requested number of batches has been collected
            if max_num_features is not None and batch_idx >= max_num_features:
                break
            collected.append(feature_extractor(images.to(device)).cpu())
    return torch.cat(collected, dim=0)
class ImageClassifier(ClassifierBase):
    """ImageClassifier specific for reproducing results of `DomainBed `_.

    You are free to freeze all `BatchNorm2d` layers and insert one additional
    `Dropout` layer; this can achieve better results for some datasets like
    PACS but may be worse for others.

    Args:
        backbone (torch.nn.Module): Any backbone to extract features from data
        num_classes (int): Number of classes
        freeze_bn (bool, optional): whether to freeze all `BatchNorm2d` layers. Default: False
        dropout_p (float, optional): dropout ratio for the additional `Dropout` layer, only used when `freeze_bn` is True. Default: 0.1
    """

    def __init__(self, backbone: nn.Module, num_classes: int, freeze_bn=False, dropout_p=0.1, **kwargs):
        super(ImageClassifier, self).__init__(backbone, num_classes, **kwargs)
        self.freeze_bn = freeze_bn
        if freeze_bn:
            self.feature_dropout = nn.Dropout(p=dropout_p)

    def forward(self, x: torch.Tensor):
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        if self.freeze_bn:
            features = self.feature_dropout(features)
        predictions = self.head(features)
        # in training mode the features are also returned (used by DG losses)
        if self.training:
            return predictions, features
        return predictions

    def train(self, mode=True):
        super(ImageClassifier, self).train(mode)
        if self.freeze_bn:
            # keep BatchNorm statistics fixed even while the rest trains
            for module in self.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()
class RandomDomainSampler(Sampler):
    r"""Randomly sample :math:`N` domains, then randomly select :math:`K` samples in each domain to form a mini-batch of
    size :math:`N\times K`.

    Args:
        data_source (ConcatDataset): dataset that contains data from multiple domains
        batch_size (int): mini-batch size (:math:`N\times K` here)
        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)
    """

    def __init__(self, data_source: ConcatDataset, batch_size: int, n_domains_per_batch: int):
        super(Sampler, self).__init__()
        self.n_domains_in_dataset = len(data_source.cumulative_sizes)
        self.n_domains_per_batch = n_domains_per_batch
        assert self.n_domains_in_dataset >= self.n_domains_per_batch

        # per-domain lists of global sample indices into the concatenated dataset
        self.sample_idxes_per_domain = []
        start = 0
        for end in data_source.cumulative_sizes:
            idxes = [idx for idx in range(start, end)]
            self.sample_idxes_per_domain.append(idxes)
            start = end

        # every mini-batch holds K = batch_size / N samples per selected domain
        assert batch_size % n_domains_per_batch == 0
        self.batch_size_per_domain = batch_size // n_domains_per_batch
        # materialize one epoch just to learn its length
        # NOTE(review): this consumes random state at construction time
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # work on a copy so each epoch starts from the full index pools
        sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)
        domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]
        final_idxes = []
        stop_flag = False
        while not stop_flag:
            # pick N domains for this mini-batch
            selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)

            for domain in selected_domains:
                sample_idxes = sample_idxes_per_domain[domain]
                if len(sample_idxes) < self.batch_size_per_domain:
                    # not enough unseen samples left: sample with replacement
                    selected_idxes = np.random.choice(sample_idxes, self.batch_size_per_domain, replace=True)
                else:
                    selected_idxes = random.sample(sample_idxes, self.batch_size_per_domain)
                final_idxes.extend(selected_idxes)

                # remove drawn samples so each index is used at most once per epoch
                # (duplicates from the replacement branch are removed only once)
                for idx in selected_idxes:
                    if idx in sample_idxes_per_domain[domain]:
                        sample_idxes_per_domain[domain].remove(idx)

                # end the epoch once any selected domain runs low on samples
                remaining_size = len(sample_idxes_per_domain[domain])
                if remaining_size < self.batch_size_per_domain:
                    stop_flag = True
        return iter(final_idxes)

    def __len__(self):
        return self.length
================================================
FILE: examples/domain_generalization/image_classification/vrex.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train, test or analyze a VREx model for domain generalization.

    ``args.phase`` selects the mode: 'train' runs the full training loop,
    'test' only evaluates the best checkpoint on the target domains, and
    'analysis' extracts features and reports t-SNE / A-distance statistics.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # training data comes from the source domains; the sampler guarantees each
    # mini-batch holds `n_domains_per_batch` equally-sized per-domain chunks
    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                                   split='train', download=True, transform=train_transform,
                                                   seed=args.seed)
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    # validation set: the held-out part of the source domains (used for model selection)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val',
                                       download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # test set: all samples of the unseen target domains
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test',
                                        download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p,
                                       finetune=args.finetune, pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # for simplicity, the optimizer reset at `anneal_iters` must coincide with an epoch boundary
    assert args.anneal_iters % args.iters_per_epoch == 0

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        if epoch * args.iters_per_epoch == args.anneal_iters:
            # reset optimizer to avoid sharp jump in gradient magnitudes
            # (the penalty weight switches from 1 to args.trade_off at this point)
            optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,
                            weight_decay=args.wd, nesterov=True)
            lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch - args.anneal_iters)
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, args.n_domains_per_batch, epoch, args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (oracle metric only; model selection uses the
        # source-domain validation accuracy above)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate the best (by validation accuracy) checkpoint on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))

    logger.close()
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          n_domains_per_batch: int, epoch: int, args: argparse.Namespace):
    """Train for one epoch with the VREx objective.

    The mini-batch is composed of `n_domains_per_batch` equally-sized per-domain
    chunks. The objective is the mean per-domain cross-entropy plus a penalty on
    the variance of the per-domain losses. The penalty weight is 1 during the
    first `args.anneal_iters` iterations and `args.trade_off` afterwards.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # forward the whole multi-domain batch at once
        y_all, _ = model(x_all)

        # one cross-entropy loss per domain chunk
        per_domain_losses = [
            F.cross_entropy(y_chunk, labels_chunk)
            for y_chunk, labels_chunk in zip(y_all.chunk(n_domains_per_batch, dim=0),
                                             labels_all.chunk(n_domains_per_batch, dim=0))
        ]
        loss_ce_per_domain = torch.stack(per_domain_losses)

        # cls loss: average over domains
        loss_ce = loss_ce_per_domain.mean()
        # penalty loss: variance of the per-domain losses
        loss_penalty = ((loss_ce_per_domain - loss_ce) ** 2).mean()

        # the penalty weight is annealed: 1 before `anneal_iters`, `trade_off` after
        global_iter = epoch * args.iters_per_epoch + i
        trade_off = args.trade_off if global_iter >= args.anneal_iters else 1

        loss = loss_ce + loss_penalty * trade_off
        cls_acc = accuracy(y_all, labels_all)[0]

        n_samples = x_all.size(0)
        losses.update(loss.item(), n_samples)
        losses_ce.update(loss_ce.item(), n_samples)
        losses_penalty.update(loss_penalty.item(), n_samples)
        cls_accs.update(cls_acc.item(), n_samples)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='VREx for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None,
                        help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None,
                        help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('--trade-off', default=3, type=float,
                        help='the trade off hyper parameter for vrex penalty')
    parser.add_argument('--anneal-iters', default=500, type=int,
                        help='anneal iterations (trade off is set to 1 during these iterations)')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='vrex',
                        help="Where to save logs, checkpoints and debugging images.")
    # fix: trailing space added so the two concatenated help sentences do not
    # render as "...the model.When phase..."
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/image_classification/vrex.sh
================================================
#!/usr/bin/env bash
# Training scripts for VREx (Variance Risk Extrapolation).
# Each command trains on the listed source domains (-s) and holds out one
# target domain (-t); logs and checkpoints are written under --log.

# ResNet50, PACS
CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_P
CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_A
CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_C
CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_S
# ResNet50, Office-Home
CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Pr
CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Rw
CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Cl
CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Ar
# ResNet50, DomainNet (more iterations, longer anneal phase, higher lr, weaker penalty)
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_c
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_i
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_p
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_q
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_r
CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_s
================================================
FILE: examples/domain_generalization/re_identification/README.md
================================================
# Domain Generalization for Person Re-Identification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
The following datasets can be downloaded automatically:
- [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html)
- [DukeMTMC](https://exposing.ai/duke_mtmc/)
- [MSMT17](https://arxiv.org/pdf/1711.08565.pdf)
## Supported Methods
Supported methods include:
- [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf)
- [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008)
## Usage
The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to
train MixStyle on Market1501 -> DukeMTMC task, use the following script
```shell script
# Train MixStyle on Market1501 -> DukeMTMC task using ResNet 50.
# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke
```
### Experiment and Results
In our experiments, we adopt the modified ResNet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf). For a fair
comparison, we use standard cross entropy loss and triplet loss in all methods.
**Notations**
- ``Avg`` means the mAP (mean average precision) reported by `TLlib`.
### Cross dataset mAP on ResNet-50
| Methods | Avg | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke |
|----------|------|-------------|-------------|-------------|-------------|-----------|-----------|
| Baseline | 23.5 | 25.6 | 29.6 | 6.3 | 31.7 | 10.1 | 37.8 |
| IBN | 27.0 | 31.5 | 33.3 | 10.4 | 33.6 | 13.7 | 40.0 |
| MixStyle | 25.5 | 27.2 | 31.6 | 8.2 | 33.9 | 12.4 | 39.9 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{IBN-Net,
author = {Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang},
title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net},
booktitle = {ECCV},
year = {2018}
}
@inproceedings{mixstyle,
title={Domain Generalization with MixStyle},
author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
booktitle={ICLR},
year={2021}
}
```
================================================
FILE: examples/domain_generalization/re_identification/baseline.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import numpy as np
import torch
import torch.nn as nn
from torch.nn import DataParallel
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
import utils
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss
from tllib.vision.models.reid.identifier import ReIdentifier
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.utils.scheduler import WarmupMultiStepLR
from tllib.utils.metric.reid import validate, visualize_ranked_results
from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train, test, or analyze a baseline person re-identification model.

    Behavior is selected by ``args.phase``:
      * ``'train'``  -- train on the source dataset, select the checkpoint by
        source-domain validation mAP, and finally report target-domain mAP.
      * ``'test'``   -- load the best checkpoint and evaluate on both domains.
      * ``'analysis'`` -- load the best checkpoint, plot t-SNE features and
        visualize ranked retrieval results on the target domain.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # seed python / numpy / torch RNGs for (best-effort) reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    # NOTE(review): benchmark=True enables cudnn kernel auto-tuning; set even
    # when deterministic=True above, so seeded runs may still vary -- confirm
    # this is intentional.
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    working_dir = osp.dirname(osp.abspath(__file__))
    root = osp.join(working_dir, args.root)
    # source dataset: the sampler groups args.num_instances images per identity
    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))
    sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances)
    train_loader = DataLoader(
        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    # validation set = union of source query and gallery images (deduplicated)
    val_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),
                                   root=source_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # target dataset: never trained on, only used to measure generalization
    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))
    test_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),
                                   root=target_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # create model: classifier head sized to the number of source identities
    num_classes = source_dataset.num_train_pids
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)
    model = DataParallel(model)
    # define optimizer and learning rate scheduler (warmup then step decay)
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,
                                     warmup_steps=args.warmup_steps)
    # resume from the best checkpoint when not training
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE of source vs. target features
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return
    if args.phase == 'test':
        print("Test on source domain:")
        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return
    # define loss function: label-smoothed cross entropy + soft triplet loss
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)
    # start training
    best_val_mAP = 0.
    best_test_mAP = 0.
    for epoch in range(args.epochs):
        # print learning rate
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)
        # update learning rate
        lr_scheduler.step()
        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # evaluate on validation set (source domain drives model selection)
            print("Validation on source domain...")
            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,
                                  cmc_flag=True)
            # remember best mAP and save checkpoint
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if val_mAP > best_val_mAP:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_val_mAP = max(val_mAP, best_val_mAP)
            # evaluate on test set; best_test_mAP is an "oracle" number since it
            # peeks at the target domain and is reported only for reference
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            best_test_mAP = max(test_mAP, best_test_mAP)
    # evaluate on test set with the checkpoint that was best on source validation
    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    print("Test on target domain:")
    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                           cmc_flag=True, rerank=args.rerank)
    print("test mAP on target = {}".format(test_mAP))
    print("oracle mAP on target = {}".format(best_test_mAP))
    logger.close()
def train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: argparse.Namespace):
    """Run one epoch of supervised ReID training.

    Optimizes a label-smoothed cross-entropy loss on identity logits plus a
    soft triplet loss on features, weighted by ``args.trade_off``, and logs
    running averages every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    last_tick = time.time()
    for step in range(args.iters_per_epoch):
        x, _, labels, _ = next(train_iter)
        x, labels = x.to(device), labels.to(device)
        # measure data loading time
        data_time.update(time.time() - last_tick)

        # forward pass: identity logits and pooled features
        y, f = model(x)
        loss_ce = criterion_ce(y, labels)
        # triplet loss operates on the features themselves
        loss_triplet = criterion_triplet(f, f, labels)
        loss = loss_ce + loss_triplet * args.trade_off
        cls_acc = accuracy(y, labels)[0]

        # bookkeeping, weighted by the actual batch size
        n = x.size(0)
        losses_ce.update(loss_ce.item(), n)
        losses_triplet.update(loss_triplet.item(), n)
        losses.update(loss.item(), n)
        cls_accs.update(cls_acc.item(), n)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - last_tick)
        last_tick = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # all dataset constructors exposed by tllib.vision.datasets.reid
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="Baseline for Domain Generalizable ReID")
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: surface the valid dataset names (previously computed but unused)
    parser.add_argument('-s', '--source', type=str,
                        help='source domain: ' + ' | '.join(dataset_names))
    parser.add_argument('-t', '--target', type=str,
                        help='target domain: ' + ' | '.join(dataset_names))
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    # training parameters
    parser.add_argument('--trade-off', type=float, default=1,
                        help='trade-off hyper parameter between cross entropy loss and triplet loss')
    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=16)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="initial learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=80)
    # fix: typo 'warp-up' -> 'warm-up'
    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')
    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],
                        help='milestones for the learning rate decay')
    parser.add_argument('--eval-step', type=int, default=40)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    # fix: help text previously said "evaluation only", which described the wrong flag
    parser.add_argument('--rerank', action='store_true', help='use re-ranking during evaluation')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/re_identification/baseline.sh
================================================
#!/usr/bin/env bash
# Cross-dataset ReID baseline: train on the source dataset (-s) and evaluate
# generalization on the unseen target dataset (-t), without any target data.
# --finetune uses a 10x smaller learning rate for the pretrained backbone.
# Market1501 -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/Market2Duke
# Duke -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/Duke2Market
# Market1501 -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/Market2MSMT
# MSMT -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/MSMT2Market
# Duke -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/Duke2MSMT
# MSMT -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
--finetune --seed 0 --log logs/baseline/MSMT2Duke
================================================
FILE: examples/domain_generalization/re_identification/ibn.sh
================================================
#!/usr/bin/env bash
# IBN-Net for domain-generalizable ReID: each task is run twice, once with the
# IBN-a variant (instance norm inside residual blocks) and once with IBN-b.
# NOTE(review): both runs of a pair log to the same directory, so the second
# run overwrites the first's checkpoints -- confirm this is intended.
# Market1501 -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/Market2Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/Market2Duke
# Duke -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/Duke2Market
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/Duke2Market
# Market1501 -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/Market2MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/Market2MSMT
# MSMT -> Market1501
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/MSMT2Market
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/MSMT2Market
# Duke -> MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/Duke2MSMT
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/Duke2MSMT
# MSMT -> Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \
--finetune --seed 0 --log logs/ibn/MSMT2Duke
CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \
--finetune --seed 0 --log logs/ibn/MSMT2Duke
================================================
FILE: examples/domain_generalization/re_identification/mixstyle.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp
import numpy as np
import torch
from torch.nn import DataParallel
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.utils.data import DataLoader
import utils
from tllib.normalization.mixstyle.sampler import RandomDomainMultiInstanceSampler
import tllib.normalization.mixstyle.resnet as models
from tllib.vision.models.reid.identifier import ReIdentifier
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.vision.models.reid.resnet import ReidResNet
from tllib.utils.scheduler import WarmupMultiStepLR
from tllib.utils.metric.reid import validate, visualize_ranked_results
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train, test, or analyze a MixStyle ReID model.

    Behavior is selected by ``args.phase``:
      * ``'train'``  -- train on the source dataset with MixStyle-augmented
        features, select the best checkpoint by source validation mAP, then
        report target-domain mAP.
      * ``'test'``   -- load the best checkpoint and evaluate on both domains.
      * ``'analysis'`` -- load the best checkpoint, plot t-SNE features and
        visualize ranked retrieval results on the target domain.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # seed python / numpy / torch RNGs for (best-effort) reproducibility
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    # NOTE(review): benchmark=True enables cudnn auto-tuning and may undermine
    # the deterministic setting above -- confirm this is intentional.
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing,
                                                random_horizontal_flip=True,
                                                random_color_jitter=False,
                                                random_gray_scale=False)
    val_transform = utils.get_val_transform(args.height, args.width)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    working_dir = osp.dirname(osp.abspath(__file__))
    root = osp.join(working_dir, args.root)
    # source dataset: sampler draws 2 (camera-)domains per batch so MixStyle
    # can mix feature statistics across them
    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))
    sampler = RandomDomainMultiInstanceSampler(source_dataset.train, batch_size=args.batch_size, n_domains_per_batch=2,
                                               num_instances=args.num_instances)
    train_loader = DataLoader(
        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    # validation set = union of source query and gallery images (deduplicated)
    val_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)),
                                   root=source_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # target dataset: never trained on, only used to measure generalization
    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))
    test_loader = DataLoader(
        convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)),
                                   root=target_dataset.images_dir,
                                   transform=val_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True)
    # create model: MixStyle-augmented ResNet backbone built on the ReID variant
    num_classes = source_dataset.num_train_pids
    backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha,
                                          resnet_class=ReidResNet, pretrained=True)
    model = ReIdentifier(backbone, num_classes, finetune=args.finetune).to(device)
    model = DataParallel(model)
    # define optimizer and learning rate scheduler (warmup then step decay)
    optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,
                     weight_decay=args.weight_decay)
    lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1,
                                     warmup_steps=args.warmup_steps)
    # resume from the best checkpoint when not training
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
    # analysis the model
    if args.phase == 'analysis':
        # plot t-SNE of source vs. target features
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return
    if args.phase == 'test':
        print("Test on source domain:")
        validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return
    # define loss function: label-smoothed cross entropy + soft triplet loss
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_triplet = SoftTripletLoss(margin=args.margin).to(device)
    # start training
    best_val_mAP = 0.
    best_test_mAP = 0.
    for epoch in range(args.epochs):
        # print learning rate
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)
        # update learning rate
        lr_scheduler.step()
        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # evaluate on validation set (source domain drives model selection)
            print("Validation on source domain...")
            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,
                                  cmc_flag=True)
            # remember best mAP and save checkpoint
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if val_mAP > best_val_mAP:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_val_mAP = max(val_mAP, best_val_mAP)
            # evaluate on test set; best_test_mAP is an "oracle" number since it
            # peeks at the target domain and is reported only for reference
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            best_test_mAP = max(test_mAP, best_test_mAP)
    # evaluate on test set with the checkpoint that was best on source validation
    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    print("Test on target domain:")
    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                           cmc_flag=True, rerank=args.rerank)
    print("test mAP on target = {}".format(test_mAP))
    print("oracle mAP on target = {}".format(best_test_mAP))
    logger.close()
def train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth,
          criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: argparse.Namespace):
    """Train the MixStyle ReID model for one epoch.

    Combines label-smoothed cross entropy on the identity logits with a soft
    triplet loss on the features, weighted by ``args.trade_off``; progress is
    printed every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    meters = [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs]
    progress = ProgressMeter(args.iters_per_epoch, meters, prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tick = time.time()
    for it in range(args.iters_per_epoch):
        images, _, pids, _ = next(train_iter)
        images = images.to(device)
        pids = pids.to(device)
        # measure data loading time
        data_time.update(time.time() - tick)

        # forward pass: identity logits and pooled features
        logits, features = model(images)
        # cross entropy on logits, soft triplet on the features themselves
        ce = criterion_ce(logits, pids)
        triplet = criterion_triplet(features, features, pids)
        total = ce + triplet * args.trade_off
        acc = accuracy(logits, pids)[0]

        # bookkeeping, weighted by actual batch size
        batch = images.size(0)
        losses_ce.update(ce.item(), batch)
        losses_triplet.update(triplet.item(), batch)
        losses.update(total.item(), batch)
        cls_accs.update(acc.item(), batch)

        # backward pass and parameter update
        optimizer.zero_grad()
        total.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tick)
        tick = time.time()
        if it % args.print_freq == 0:
            progress.display(it)
if __name__ == '__main__':
    # all MixStyle-capable backbone constructors
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    # all dataset constructors exposed by tllib.vision.datasets.reid
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="MixStyle for Domain Generalizable ReID")
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    # fix: surface the valid dataset names (previously computed but unused)
    parser.add_argument('-s', '--source', type=str,
                        help='source domain: ' + ' | '.join(dataset_names))
    parser.add_argument('-t', '--target', type=str,
                        help='target domain: ' + ' | '.join(dataset_names))
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: resnet50)')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    parser.add_argument('--mix-layers', nargs='+', help='layers to apply MixStyle')
    parser.add_argument('--mix-p', default=0.5, type=float, help='probability to apply MixStyle')
    parser.add_argument('--mix-alpha', default=0.1, type=float, help='parameter alpha for beta distribution')
    # training parameters
    parser.add_argument('--trade-off', type=float, default=1,
                        help='trade-off hyper parameter between cross entropy loss and triplet loss')
    parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=16)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    # fix: help text was truncated mid-sentence ("for pretrained ")
    parser.add_argument('--lr', type=float, default=0.00035,
                        help="initial learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=80)
    parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps')
    parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70],
                        help='milestones for the learning rate decay')
    parser.add_argument('--eval-step', type=int, default=40)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    # fix: help text previously said "evaluation only", which described the wrong flag
    parser.add_argument('--rerank', action='store_true', help='use re-ranking during evaluation')
    parser.add_argument("--log", type=str, default='mixstyle',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_generalization/re_identification/mixstyle.sh
================================================
#!/usr/bin/env bash
# MixStyle for domain-generalizable ReID: train on the source dataset (-s),
# evaluate on the unseen target dataset (-t). MixStyle is applied to the
# backbone's layer1 and layer2 feature maps; --finetune uses a 10x smaller
# learning rate for the pretrained backbone.
# Market1501 -> Duke
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke
# Duke -> Market1501
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t Market1501 -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2Market
# Market1501 -> MSMT
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t MSMT17 -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2MSMT
# MSMT -> Market1501
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t Market1501 -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Market
# Duke -> MSMT
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t MSMT17 -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2MSMT
# MSMT -> Duke
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t DukeMTMC -a resnet50 \
--mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Duke
================================================
FILE: examples/domain_generalization/re_identification/requirements.txt
================================================
timm
opencv-python
================================================
FILE: examples/domain_generalization/re_identification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import sys
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T
sys.path.append('../../..')
from tllib.utils.metric.reid import extract_reid_feature
from tllib.utils.analysis import tsne
import tllib.vision.models.reid as models
import tllib.normalization.ibn as ibn_models
def get_model_names():
    """Return the names of all supported backbone architectures.

    Combines the lowercase callable entries of ``tllib.vision.models.reid``
    and ``tllib.normalization.ibn`` (each group sorted alphabetically) with
    every model name known to ``timm``.
    """
    def _factory_names(module):
        # Lowercase, non-dunder, callable module attributes are model factories.
        return sorted(
            name for name, obj in module.__dict__.items()
            if name.islower() and not name.startswith("__") and callable(obj)
        )

    return _factory_names(models) + _factory_names(ibn_models) + timm.list_models()
def get_model(model_name):
    """Create an ImageNet-pretrained backbone by name.

    Looks up ``model_name`` first in ``tllib.vision.models.reid``, then in
    ``tllib.normalization.ibn``, and finally falls back to
    ``timm.create_model``. For timm models the classification head is
    stripped so the module outputs features, and ``out_features`` is set to
    the head's input dimension.

    Args:
        model_name (str): architecture name; see :func:`get_model_names`.

    Returns:
        A feature-extracting backbone; timm backbones additionally carry an
        ``out_features`` attribute.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    elif model_name in ibn_models.__dict__:
        # load models (with ibn) from tllib.normalization.ibn
        backbone = ibn_models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except AttributeError:
            # Some timm models expose the head as a ``head`` attribute instead
            # of the ``get_classifier()``/``reset_classifier()`` API. Catch only
            # AttributeError (a bare ``except`` would also swallow e.g.
            # KeyboardInterrupt and real bugs).
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone
def get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False,
                        random_gray_scale=False):
    """Build the training-time data augmentation pipeline.

    resizing mode:
        - default: resize the image to (height, width), zero-pad it by 10 on
          each side, then take a random crop of (height, width)
        - res: resize the image to (height, width)
    """
    if resizing == 'default':
        base = T.Compose([
            T.Resize((height, width), interpolation=3),
            T.Pad(10),
            T.RandomCrop((height, width)),
        ])
    elif resizing == 'res':
        base = T.Resize((height, width), interpolation=3)
    else:
        raise NotImplementedError(resizing)

    pipeline = [base]
    # Optional photometric / geometric augmentations, applied in this order.
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        pipeline.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))
    if random_gray_scale:
        pipeline.append(T.RandomGrayscale())
    # Convert to tensor and normalize with ImageNet statistics.
    pipeline += [
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return T.Compose(pipeline)
def get_val_transform(height, width):
    """Build the evaluation-time preprocessing pipeline: resize to
    (height, width), convert to tensor, normalize with ImageNet statistics."""
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        normalize,
    ])
def visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):
    """Visualize features from different domains using t-SNE. As we can have very large number of samples in each
    domain, only `n_data_points_per_domain` number of samples are randomly selected in each domain.
    """
    def _sample_features(loader):
        # Extract normalized re-id features, shuffle them, and keep at most
        # `n_data_points_per_domain` samples.
        feature_dict = extract_reid_feature(loader, model, device, normalize=True)
        feats = torch.stack(list(feature_dict.values())).cpu()
        shuffled = feats[torch.randperm(len(feats))]
        return shuffled[:n_data_points_per_domain]

    source_feature = _sample_features(source_loader)
    target_feature = _sample_features(target_loader)
    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue', target_color='darkorange')
    print('T-SNE process is done, figure is saved to {}'.format(filename))
================================================
FILE: examples/model_selection/README.md
================================================
# Model Selection
## Installation
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/)
- [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)
- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)
- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [SUN397](https://vision.princeton.edu/projects/2010/SUN/)
## Supported Methods
Supported methods include:
- [An Information-theoretic Approach to Transferability in Task Transfer Learning (H-Score, ICIP 2019)](http://yangli-feasibility.com/home/media/icip-19.pdf)
- [LEEP: A New Measure to Evaluate Transferability of Learned Representations (LEEP, ICML 2020)](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf)
- [Log Maximum Evidence in "LogME: Practical Assessment of Pre-trained Models for Transfer Learning" (LogME, ICML 2021)](https://arxiv.org/pdf/2102.11005.pdf)
- [Negative Conditional Entropy in "Transferability and Hardness of Supervised Classification Tasks" (NCE, ICCV 2019)](https://arxiv.org/pdf/1908.08142v1.pdf)
## Experiment and Results
### Model Ranking on image classification tasks
The shell files give the scripts to rank pre-trained models on a given dataset. For example, if you want to use LogME to calculate the transfer performance of ResNet50 (ImageNet pre-trained) on Aircraft, use the following script
```shell script
# Using LogME to rank pre-trained ResNet50 on Aircraft
# Assume you have put the dataset under the path `data/FGVCAircraft`,
# or you are glad to download the dataset automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features
```
We use LEEP, NCE, HScore and LogME to compute scores by applying 10 pre-trained models to different datasets. The correlations ([Weighted Kendall Tau](https://vigna.di.unimi.it/ftp/papers/WeightedTau.pdf) / Pearson correlation) between scores and fine-tuned accuracies
are presented.
#### Model Ranking Benchmark on Aircraft
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 82.7 | 28.37 | -4.310 | 0.934 | -4.248 |
| Inception V3 | 88.8 | 43.89 | -4.202 | 0.953 | -4.170 |
| ResNet50 | 86.6 | 46.23 | -4.215 | 0.946 | -4.201 |
| ResNet101 | 85.6 | 46.13 | -4.230 | 0.948 | -4.222 |
| ResNet152 | 85.3 | 46.25 | -4.230 | 0.950 | -4.229 |
| DenseNet121 | 85.4 | 31.53 | -4.228 | 0.938 | -4.215 |
| DenseNet169 | 84.5 | 41.81 | -4.245 | 0.943 | -4.270 |
| Densenet201 | 84.6 | 46.01 | -4.206 | 0.942 | -4.189 |
| MobileNet V2 | 82.8 | 34.43 | -4.198 | 0.941 | -4.208 |
| MNasNet | 72.8 | 35.28 | -4.192 | 0.948 | -4.195 |
| Pearson Corr | - | 0.688 | 0.127 | 0.582 | 0.173 |
| Weighted Tau | - | 0.664 | -0.264 | 0.595 | 0.002 |
#### Model Ranking Benchmark on Caltech101
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 91.7 | 75.88 | -1.462 | 1.228 | -0.665 |
| Inception V3 | 94.3 | 93.73 | -1.119 | 1.387 | -0.560 |
| ResNet50 | 91.8 | 91.65 | -1.020 | 1.262 | -0.616 |
| ResNet101 | 93.1 | 92.54 | -0.899 | 1.305 | -0.603 |
| ResNet152 | 93.2 | 92.91 | -0.875 | 1.324 | -0.605 |
| DenseNet121 | 91.9 | 75.02 | -0.979 | 1.172 | -0.609 |
| DenseNet169 | 92.5 | 86.37 | -0.864 | 1.212 | -0.580 |
| Densenet201 | 93.4 | 89.90 | -0.914 | 1.228 | -0.590 |
| MobileNet V2 | 89.1 | 75.82 | -1.115 | 1.150 | -0.693 |
| MNasNet | 91.5 | 77.00 | -1.043 | 1.178 | -0.690 |
| Pearson Corr | - | 0.748 | 0.324 | 0.794 | 0.843 |
| Weighted Tau | - | 0.721 | 0.127 | 0.697 | 0.810 |
#### Model Ranking Benchmark on CIFAR10
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 96.2 | 5.911 | -1.385 | 0.293 | -1.139 |
| Inception V3 | 97.5 | 6.363 | -1.259 | 0.349 | -1.060 |
| ResNet50 | 96.8 | 6.567 | -1.010 | 0.388 | -1.007 |
| ResNet101 | 97.7 | 6.901 | -0.829 | 0.463 | -0.838 |
| ResNet152 | 97.9 | 6.945 | -0.838 | 0.469 | -0.851 |
| DenseNet121 | 97.2 | 6.210 | -1.035 | 0.302 | -1.006 |
| DenseNet169 | 97.4 | 6.547 | -0.934 | 0.343 | -0.946 |
| Densenet201 | 97.4 | 6.706 | -0.888 | 0.369 | -0.866 |
| MobileNet V2 | 95.7 | 5.928 | -1.100 | 0.291 | -1.089 |
| MNasNet | 96.8 | 6.018 | -1.066 | 0.304 | -1.086 |
| Pearson Corr | - | 0.839 | 0.604 | 0.733 | 0.786 |
| Weighted Tau | - | 0.800 | 0.638 | 0.785 | 0.714 |
#### Model Ranking Benchmark on CIFAR100
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 83.2 | 29.33 | -3.234 | 1.037 | -2.751 |
| Inception V3 | 86.6 | 36.47 | -2.995 | 1.070 | -2.615 |
| ResNet50 | 84.5 | 40.20 | -2.612 | 1.099 | -2.516 |
| ResNet101 | 87.0 | 43.80 | -2.365 | 1.130 | -2.285 |
| ResNet152 | 87.6 | 44.19 | -2.410 | 1.133 | -2.369 |
| DenseNet121 | 84.8 | 32.13 | -2.665 | 1.029 | -2.504 |
| DenseNet169 | 85.0 | 37.51 | -2.494 | 1.051 | -2.418 |
| Densenet201 | 86.0 | 39.75 | -2.470 | 1.061 | -2.305 |
| MobileNet V2 | 80.8 | 30.36 | -2.800 | 1.039 | -2.653 |
| MNasNet | 83.9 | 32.05 | -2.732 | 1.051 | -2.643 |
| Pearson Corr | - | 0.815 | 0.513 | 0.698 | 0.705 |
| Weighted Tau | - | 0.775 | 0.659 | 0.790 | 0.654 |
#### Model Ranking Benchmark on DTD
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|-------|
| GoogleNet | 73.6 | 34.61 | -2.333 | 0.682 | 0.682 |
| Inception V3 | 77.2 | 57.17 | -2.135 | 0.691 | 0.691 |
| ResNet50 | 75.2 | 78.26 | -1.985 | 0.695 | 0.695 |
| ResNet101 | 76.2 | 117.23 | -1.974 | 0.689 | 0.689 |
| ResNet152 | 75.4 | 32.30 | -1.924 | 0.698 | 0.698 |
| DenseNet121 | 74.9 | 35.23 | -2.001 | 0.670 | 0.670 |
| DenseNet169 | 74.8 | 43.36 | -1.817 | 0.686 | 0.686 |
| Densenet201 | 74.5 | 45.96 | -1.926 | 0.689 | 0.689 |
| MobileNet V2 | 72.9 | 37.99 | -2.098 | 0.664 | 0.664 |
| MNasNet | 72.8 | 38.03 | -2.033 | 0.679 | 0.679 |
| Pearson Corr | - | 0.532 | 0.217 | 0.617 | 0.471 |
| Weighted Tau | - | 0.416 | -0.004 | 0.550 | 0.083 |
#### Model Ranking Benchmark on OxfordIIITPets
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 91.9 | 28.02 | -1.064 | 0.854 | -0.815 |
| Inception V3 | 93.5 | 33.29 | -0.888 | 1.119 | -0.711 |
| ResNet50 | 92.5 | 32.55 | -0.805 | 0.952 | -0.721 |
| ResNet101 | 94.0 | 32.76 | -0.769 | 0.985 | -0.717 |
| ResNet152 | 94.5 | 32.86 | -0.732 | 1.009 | -0.679 |
| DenseNet121 | 92.9 | 27.09 | -0.837 | 0.797 | -0.753 |
| DenseNet169 | 93.1 | 30.09 | -0.779 | 0.829 | -0.699 |
| Densenet201 | 92.8 | 31.25 | -0.810 | 0.860 | -0.716 |
| MobileNet V2 | 90.5 | 27.83 | -0.902 | 0.765 | -0.822 |
| MNasNet | 89.4 | 27.95 | -0.854 | 0.785 | -0.812 |
| Pearson Corr | - | 0.427 | -0.127 | 0.589 | 0.501 |
| Weighted Tau | - | 0.425 | -0.143 | 0.502 | 0.119 |
#### Model Ranking Benchmark on StanfordCars
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 91.0 | 41.47 | -4.612 | 1.246 | -4.312 |
| Inception V3 | 92.3 | 73.68 | -4.268 | 1.259 | -4.110 |
| ResNet50 | 91.7 | 72.94 | -4.366 | 1.253 | -4.221 |
| ResNet101 | 91.7 | 73.98 | -4.281 | 1.255 | -4.218 |
| ResNet152 | 92.0 | 76.17 | -4.215 | 1.260 | -4.142 |
| DenseNet121 | 91.5 | 45.82 | -4.437 | 1.249 | -4.271 |
| DenseNet169 | 91.5 | 63.40 | -4.286 | 1.252 | -4.175 |
| Densenet201 | 91.0 | 70.50 | -4.319 | 1.251 | -4.151 |
| MobileNet V2 | 91.0 | 51.12 | -4.463 | 1.250 | -4.306 |
| MNasNet | 88.5 | 51.91 | -4.423 | 1.254 | -4.338 |
| Pearson Corr | - | 0.503 | 0.433 | 0.274 | 0.695 |
| Weighted Tau | - | 0.638 | 0.703 | 0.654 | 0.750 |
#### Model Ranking Benchmark on SUN397
| Model | Finetuned Acc | HScore | LEEP | LogME | NCE |
|--------------|---------------|--------|--------|-------|--------|
| GoogleNet | 62.0 | 71.35 | -3.744 | 1.621 | -3.055 |
| Inception V3 | 65.7 | 114.21 | -3.372 | 1.648 | -2.844 |
| ResNet50 | 64.7 | 110.39 | -3.198 | 1.638 | -2.894 |
| ResNet101 | 64.8 | 113.63 | -3.103 | 1.642 | -2.837 |
| ResNet152 | 66.0 | 116.51 | -3.056 | 1.646 | -2.822 |
| DenseNet121 | 62.3 | 72.16 | -3.311 | 1.614 | -2.945 |
| DenseNet169 | 63.0 | 95.80 | -3.165 | 1.623 | -2.903 |
| Densenet201 | 64.7 | 103.09 | -3.205 | 1.624 | -2.896 |
| MobileNet V2 | 60.5 | 75.90 | -3.338 | 1.617 | -2.968 |
| MNasNet | 60.7 | 80.91 | -3.234 | 1.625 | -2.933 |
| Pearson Corr | - | 0.913 | 0.428 | 0.824 | 0.782 |
| Weighted Tau | - | 0.918 | 0.581 | 0.748 | 0.873 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{bao_information-theoretic_2019,
title = {An Information-Theoretic Approach to Transferability in Task Transfer Learning},
booktitle = {ICIP},
author = {Bao, Yajie and Li, Yang and Huang, Shao-Lun and Zhang, Lin and Zheng, Lizhong and Zamir, Amir and Guibas, Leonidas},
year = {2019}
}
@inproceedings{nguyen_leep:_2020,
title = {LEEP: A New Measure to Evaluate Transferability of Learned Representations},
booktitle = {ICML},
author = {Nguyen, Cuong and Hassner, Tal and Seeger, Matthias and Archambeau, Cedric},
year = {2020}
}
@inproceedings{you_logme:_2021,
title = {LogME: Practical Assessment of Pre-trained Models for Transfer Learning},
booktitle = {ICML},
author = {You, Kaichao and Liu, Yong and Wang, Jianmin and Long, Mingsheng},
year = {2021}
}
@inproceedings{tran_transferability_2019,
title = {Transferability and hardness of supervised classification tasks},
booktitle = {ICCV},
author = {Tran, Anh T. and Nguyen, Cuong V. and Hassner, Tal},
year = {2019}
}
```
================================================
FILE: examples/model_selection/hscore.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
sys.path.append('../..')
from tllib.ranking import h_score
sys.path.append('.')
import utils
# Run on the first visible GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Rank the pre-trained model `args.arch` on dataset `args.data` with H-Score.

    Extracted features/predictions/targets are cached as ``.npy`` files in the
    logger's save directory: if the cached arrays exist they are reused,
    otherwise the dataset is forwarded through the model once (and optionally
    cached via ``--save_features``). The resulting score is appended to the
    logger's result file.
    """
    logger = utils.Logger(args.data, args.arch, 'results_hscore')
    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')
    try:
        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))
        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))
        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))
        print('Loaded extracted features')
    except OSError:
        # Cache miss or unreadable cache files (np.load raises OSError /
        # FileNotFoundError): extract the features from scratch. A bare
        # ``except`` here would also hide genuine bugs and KeyboardInterrupt.
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)
        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,
                                                       args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                  pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')
        # NOTE(review): ``eval`` resolves layer expressions such as
        # ``classifier[-1]``; args.layer comes from the command line, so this
        # script must only be run with trusted arguments.
        features, predictions, targets = utils.forwarding_dataset(score_loader, model,
                                                                  layer=eval(f'model.{args.layer}'), device=device)
        if args.save_features:
            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)
            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)
            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)
    print('Conducting transferability calculation')
    result = h_score(features, targets)
    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    print(f'Results saved in {logger.get_result_dir()}')
    logger.close()
if __name__ == '__main__':
    # Command-line interface: parse dataset/model/feature-extraction options
    # and hand them to main().
    parser = argparse.ArgumentParser(description='Ranking pre-trained models with HScore')
    # dataset
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N', help='mini-batch size (default: 48)')
    # Resizing mode string understood by utils.get_transform (e.g. 'res.',
    # 'res.299' for Inception-style 299x299 inputs).
    parser.add_argument('--resizing', default='res.', type=str)
    # model
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='model to be ranked: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    # Evaluated as an attribute expression on the model, e.g. 'fc',
    # 'classifier', 'classifier[-1]'.
    parser.add_argument('-l', '--layer', default='fc',
                        help='before which layer features are extracted')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument("--save_features", action='store_true',
                        help="whether to save extracted features")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/model_selection/hscore.sh
================================================
#!/usr/bin/env bash
# Ranking Pre-trained Model
# For each dataset, compute the H-Score of 10 ImageNet-pretrained torchvision
# models. The -l flag names the classification layer before which features are
# extracted: 'fc' for ResNet/GoogLeNet/Inception, 'classifier' for DenseNet,
# 'classifier[-1]' for MobileNetV2/MNASNet. Inception V3 needs 299x299 inputs
# (--resizing res.299). --save_features caches extracted features on disk so
# other ranking scripts (LEEP, LogME, NCE) can reuse them.
# ======================================================================================================================
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# FGVCAircraft
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Caltech101
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Oxford-IIIT
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features
================================================
FILE: examples/model_selection/leep.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
sys.path.append('../..')
from tllib.ranking import log_expected_empirical_prediction as leep
sys.path.append('.')
import utils
# Run on the first visible GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Score a pre-trained model on a target dataset with LEEP.

    Cached ``features.npy`` / ``preds.npy`` / ``targets.npy`` in the logger's
    save directory are reused when present; otherwise the model's outputs are
    extracted from the dataset (and optionally cached with --save_features).
    The LEEP score is appended to the result log.
    """
    logger = utils.Logger(args.data, args.arch, 'results_leep')
    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')
    save_dir = logger.get_save_dir()
    try:
        features = np.load(os.path.join(save_dir, 'features.npy'))
        predictions = np.load(os.path.join(save_dir, 'preds.npy'))
        targets = np.load(os.path.join(save_dir, 'targets.npy'))
        print('Loaded extracted features')
    except (OSError, ValueError):
        # Cache is missing or unreadable -> extract from scratch.  The
        # previous bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)
        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,
                                                       args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                  pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')
        # NOTE(security): ``eval`` resolves the CLI layer expression (e.g.
        # ``classifier[-1]``) against the model; only run with trusted input.
        features, predictions, targets = utils.forwarding_dataset(score_loader, model,
                                                                  layer=eval(f'model.{args.layer}'), device=device)
        if args.save_features:
            np.save(os.path.join(save_dir, 'features.npy'), features)
            np.save(os.path.join(save_dir, 'preds.npy'), predictions)
            np.save(os.path.join(save_dir, 'targets.npy'), targets)
    print('Conducting transferability calculation')
    # LEEP only needs the source-model predictions and the target labels.
    result = leep(predictions, targets)
    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    logger.close()
if __name__ == '__main__':
    # Command-line interface for the LEEP ranking script.
    cli = argparse.ArgumentParser(
        description='Ranking pre-trained models with LEEP (Log Expected Empirical Prediction)')
    # --- dataset options ---
    cli.add_argument('root', metavar='DIR',
                     help='root path of dataset')
    cli.add_argument('-d', '--data', metavar='DATA')
    cli.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
    cli.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N',
                     help='sample rate of training dataset (default: 100)')
    cli.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                     help='number of samples per classes.')
    cli.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                     help='mini-batch size (default: 48)')
    cli.add_argument('--resizing', default='res.', type=str)
    # --- model options ---
    model_names = utils.get_model_names()
    cli.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                     choices=model_names,
                     help='model to be ranked: ' + ' | '.join(model_names) +
                          ' (default: resnet50)')
    cli.add_argument('-l', '--layer', default='fc',
                     help='before which layer features are extracted')
    cli.add_argument('--pretrained', default=None,
                     help="pretrained checkpoint of the backbone. "
                          "(default: None, use the ImageNet supervised pretrained backbone)")
    cli.add_argument("--save_features", action='store_true',
                     help="whether to save extracted features")
    main(cli.parse_args())
================================================
FILE: examples/model_selection/leep.sh
================================================
#!/usr/bin/env bash
# Ranking Pre-trained Model
# Runs the LEEP score for a fixed sweep of architectures on one dataset.
# inception_v3 needs 299x299 inputs (extra --resizing flag); DenseNet heads
# are ranked at `classifier`, MobileNet/MNASNet at `classifier[-1]`.
# --save_features caches extracted features for reuse by other metrics.
rank_dataset() {
  local root=$1
  local data=$2
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a resnet50 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a resnet101 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a resnet152 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a googlenet -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a inception_v3 --resizing res.299 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a densenet121 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a densenet169 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a densenet201 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a mobilenet_v2 -l 'classifier[-1]' --save_features
  CUDA_VISIBLE_DEVICES=0 python leep.py "${root}" -d "${data}" -a mnasnet1_0 -l 'classifier[-1]' --save_features
}

# Same dataset order as the other ranking scripts.
rank_dataset ./data/cifar10 CIFAR10
rank_dataset ./data/cifar100 CIFAR100
rank_dataset ./data/FGVCAircraft Aircraft
rank_dataset ./data/caltech101 Caltech101
rank_dataset ./data/dtd DTD
rank_dataset ./data/Oxford-IIIT OxfordIIITPets
rank_dataset ./data/stanford_cars StanfordCars
rank_dataset ./data/SUN397 SUN397
================================================
FILE: examples/model_selection/logme.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
sys.path.append('../..')
from tllib.ranking import log_maximum_evidence as logme
sys.path.append('.')
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Score a pre-trained model on a target dataset with LogME.

    Cached ``features.npy`` / ``preds.npy`` / ``targets.npy`` in the logger's
    save directory are reused when present; otherwise the model's outputs are
    extracted from the dataset (and optionally cached with --save_features).
    The LogME score is appended to the result log.
    """
    logger = utils.Logger(args.data, args.arch, 'results_logme')
    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')
    save_dir = logger.get_save_dir()
    try:
        features = np.load(os.path.join(save_dir, 'features.npy'))
        predictions = np.load(os.path.join(save_dir, 'preds.npy'))
        targets = np.load(os.path.join(save_dir, 'targets.npy'))
        print('Loaded extracted features')
    except (OSError, ValueError):
        # Cache is missing or unreadable -> extract from scratch.  The
        # previous bare ``except:`` also swallowed KeyboardInterrupt/SystemExit.
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)
        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,
                                                       args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                  pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')
        # NOTE(security): ``eval`` resolves the CLI layer expression (e.g.
        # ``classifier[-1]``) against the model; only run with trusted input.
        features, predictions, targets = utils.forwarding_dataset(score_loader, model,
                                                                  layer=eval(f'model.{args.layer}'), device=device)
        if args.save_features:
            np.save(os.path.join(save_dir, 'features.npy'), features)
            np.save(os.path.join(save_dir, 'preds.npy'), predictions)
            np.save(os.path.join(save_dir, 'targets.npy'), targets)
    print('Conducting transferability calculation')
    # LogME scores the penultimate-layer features against the target labels.
    result = logme(features, targets)
    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    logger.close()
if __name__ == '__main__':
    # Command-line interface for the LogME ranking script.
    cli = argparse.ArgumentParser(description='Ranking pre-trained models with LogME (Log Maximum Evidence)')
    # --- dataset options ---
    cli.add_argument('root', metavar='DIR',
                     help='root path of dataset')
    cli.add_argument('-d', '--data', metavar='DATA')
    cli.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
    cli.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N',
                     help='sample rate of training dataset (default: 100)')
    cli.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                     help='number of samples per classes.')
    cli.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                     help='mini-batch size (default: 48)')
    cli.add_argument('--resizing', default='res.', type=str)
    # --- model options ---
    model_names = utils.get_model_names()
    cli.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                     choices=model_names,
                     help='model to be ranked: ' + ' | '.join(model_names) +
                          ' (default: resnet50)')
    cli.add_argument('-l', '--layer', default='fc',
                     help='before which layer features are extracted')
    cli.add_argument('--pretrained', default=None,
                     help="pretrained checkpoint of the backbone. "
                          "(default: None, use the ImageNet supervised pretrained backbone)")
    cli.add_argument("--save_features", action='store_true',
                     help="whether to save extracted features")
    main(cli.parse_args())
================================================
FILE: examples/model_selection/logme.sh
================================================
#!/usr/bin/env bash
# Ranking Pre-trained Model
# Runs the LogME score for a fixed sweep of architectures on one dataset.
# inception_v3 needs 299x299 inputs (extra --resizing flag); DenseNet heads
# are ranked at `classifier`, MobileNet/MNASNet at `classifier[-1]`.
# --save_features caches extracted features for reuse by other metrics.
rank_dataset() {
  local root=$1
  local data=$2
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a resnet50 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a resnet101 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a resnet152 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a googlenet -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a inception_v3 --resizing res.299 -l fc --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a densenet121 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a densenet169 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a densenet201 -l classifier --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a mobilenet_v2 -l 'classifier[-1]' --save_features
  CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${data}" -a mnasnet1_0 -l 'classifier[-1]' --save_features
}

# Same dataset order as the other ranking scripts.
rank_dataset ./data/cifar10 CIFAR10
rank_dataset ./data/cifar100 CIFAR100
rank_dataset ./data/FGVCAircraft Aircraft
rank_dataset ./data/caltech101 Caltech101
rank_dataset ./data/dtd DTD
rank_dataset ./data/Oxford-IIIT OxfordIIITPets
rank_dataset ./data/stanford_cars StanfordCars
rank_dataset ./data/SUN397 SUN397
================================================
FILE: examples/model_selection/nce.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
sys.path.append('../..')
from tllib.ranking import negative_conditional_entropy as nce
sys.path.append('.')
import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args):
    """Rank the transferability of ``args.arch`` on ``args.data`` with NCE.

    Cached features/predictions/targets are reused when present under the
    logger's save directory; otherwise the whole dataset is forwarded through
    the model once (optionally caching the results with ``--save_features``).
    The resulting NCE score is appended to the metric's result file.
    """
    logger = utils.Logger(args.data, args.arch, 'results_nce')

    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')

    try:
        # Fast path: reuse arrays cached by a previous run with --save_features.
        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))
        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))
        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))
        print('Loaded extracted features')
    except (OSError, ValueError):
        # np.load raises OSError (incl. FileNotFoundError) for missing files
        # and ValueError for unreadable ones. Previously a bare `except:`,
        # which would also swallow e.g. KeyboardInterrupt.
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)

        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate,
                                                       args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers,
                                  pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')

        # NOTE(review): `eval` on a CLI-provided layer path (e.g. `fc`,
        # `classifier[-1]`) is acceptable for a local research script, but
        # must not be exposed to untrusted input.
        features, predictions, targets = utils.forwarding_dataset(score_loader, model,
                                                                  layer=eval(f'model.{args.layer}'), device=device)
        if args.save_features:
            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)
            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)
            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)

    print('Conducting transferability calculation')
    # NCE scores hard predictions (argmax over class probabilities) vs labels.
    result = nce(np.argmax(predictions, axis=1), targets)

    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    logger.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ranking pre-trained models with NCE (Negative Conditional Entropy)')
    # dataset
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N', help='mini-batch size (default: 48)')
    # input resizing policy, resolved by utils.get_transform
    # ('res.' resizes to 224x224; 'res.299' is used for inception_v3)
    parser.add_argument('--resizing', default='res.', type=str)
    # model
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='model to be ranked: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    # evaluated via eval(f'model.{args.layer}') in main, so index syntax such
    # as 'classifier[-1]' is accepted
    parser.add_argument('-l', '--layer', default='fc',
                        help='before which layer features are extracted')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument("--save_features", action='store_true',
                        help="whether to save extracted features")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/model_selection/nce.sh
================================================
#!/usr/bin/env bash
# Ranking Pre-trained Model
# Rank supervised ImageNet pre-trained models on each downstream dataset with
# NCE (Negative Conditional Entropy). `-l` names the layer before which
# features are hooked (the classification head of each architecture; note the
# `classifier[-1]` index syntax is evaluated by nce.py). `--save_features`
# caches features/predictions/targets so repeated runs skip extraction.
# inception_v3 requires 299x299 inputs, hence `--resizing res.299`.
# ======================================================================================================================
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# FGVCAircraft
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Caltech101
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Oxford-IIIT
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python nce.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features
================================================
FILE: examples/model_selection/requirements.txt
================================================
timm
numba
================================================
FILE: examples/model_selection/utils.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import random
import sys, os
import torch
import timm
from torch.utils.data import Subset
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.models as models
sys.path.append('../../..')
import tllib.vision.datasets as datasets
class Logger(object):
    """Mirrors written messages to a stream and an append-mode result file.

    Args:
        data_name (str): dataset name; used as the root of the output tree
        model_name (str): model name; intermediate features/outputs are saved
            under ``<data_name>/<model_name>`` (see :meth:`get_save_dir`)
        metric_name (str): ranking results are appended to
            ``<data_name>/<metric_name>.txt`` (see :meth:`get_result_dir`)
        stream: the stream to mirror output to. Default: sys.stdout
    """

    def __init__(self, data_name, model_name, metric_name, stream=sys.stdout):
        self.terminal = stream
        self.save_dir = os.path.join(data_name, model_name)  # save intermediate features/outputs
        self.result_dir = os.path.join(data_name, f'{metric_name}.txt')  # save ranking results
        os.makedirs(self.save_dir, exist_ok=True)
        self.log = open(self.result_dir, 'a')

    def write(self, message):
        """Write ``message`` to both the stream and the result file, flushing immediately."""
        self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def get_save_dir(self):
        return self.save_dir

    def get_result_dir(self):
        return self.result_dir

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        # Fix: never close an interpreter-owned standard stream. `stream`
        # defaults to sys.stdout, and closing it would break every later
        # print() in the process.
        if self.terminal not in (sys.stdout, sys.stderr):
            self.terminal.close()
        self.log.close()
def get_model_names():
    """Return all rankable model names: torchvision factory functions plus timm models."""
    torchvision_names = [
        name for name in models.__dict__
        if name.islower()
        and not name.startswith("__")
        and callable(models.__dict__[name])
    ]
    return sorted(torchvision_names) + timm.list_models()
def forwarding_dataset(score_loader, model, layer, device):
    """
    Forward the full dataset once and collect intermediate activations.

    :params score_loader: the dataloader for scoring transferability
    :params model: the model for scoring transferability
    :params layer: before which layer features are extracted, for registering hooks

    returns
        features: extracted features of model
        prediction: probability outputs of model
        targets: ground-truth labels of dataset
    """
    feature_chunks, output_chunks, target_chunks = [], [], []

    def collect(module, input, output):
        # `input` is a tuple of positional args; input[0] is the tensor fed
        # into `layer`, i.e. the features "before" it.
        feature_chunks.append(input[0].detach().cpu())
        output_chunks.append(output.detach().cpu())

    handle = layer.register_forward_hook(collect)

    model.eval()
    with torch.no_grad():
        for data, target in score_loader:
            target_chunks.append(target)
            _ = model(data.to(device))
    handle.remove()

    features = torch.cat(feature_chunks).numpy()
    predictions = F.softmax(torch.cat(output_chunks), dim=-1).numpy()
    targets = torch.cat(target_chunks).numpy()
    return features, predictions, targets
def get_model(model_name, pretrained=True, pretrained_checkpoint=None):
    """Create a backbone by name, from torchvision when available there,
    otherwise from pytorch-image-models (timm).

    Args:
        model_name (str): architecture name.
        pretrained: whether to load the library's default pre-trained weights.
            NOTE(review): callers in this directory pass ``args.pretrained``
            (a checkpoint *path* or None) as this argument rather than as
            ``pretrained_checkpoint`` — verify the intended call convention.
        pretrained_checkpoint (str, optional): path to a checkpoint whose
            state dict is loaded non-strictly on top of the backbone.
    """
    # Fix: the old check `model_name in get_model_names()` is also True for
    # timm-only names (that list is torchvision + timm), which routed timm
    # models into `models.__dict__[model_name]` and raised KeyError. Check
    # torchvision's own namespace instead.
    if model_name in models.__dict__ and callable(models.__dict__[model_name]):
        # load models from torchvision
        backbone = models.__dict__[model_name](pretrained=pretrained)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=pretrained)
    if pretrained_checkpoint:
        print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint))
        pretrained_dict = torch.load(pretrained_checkpoint)
        # strict=False: tolerate missing/extra keys (e.g. a different head).
        backbone.load_state_dict(pretrained_dict, strict=False)
    return backbone
def get_dataset(dataset_name, root, transform, sample_rate=100, num_samples_per_classes=None, split='train'):
    """
    Build the dataset used for scoring transferability.

    When sample_rate < 100, e.g. sample_rate = 50, use 50% data to train the model.
    Otherwise,
        if num_samples_per_classes is not None, e.g. 5, keep a random subset of
        ``num_samples_per_classes * num_classes`` images
        (NOTE(review): sampling is uniform over the whole dataset, not
        stratified per class as the parameter name suggests — confirm intent);
        otherwise, keep all the data.

    Returns:
        (dataset, num_classes)
    """
    dataset = datasets.__dict__[dataset_name]
    if sample_rate < 100:
        score_dataset = dataset(root=root, split=split, sample_rate=sample_rate, download=True, transform=transform)
        num_classes = len(score_dataset.classes)
    else:
        score_dataset = dataset(root=root, split=split, download=True, transform=transform)
        num_classes = len(score_dataset.classes)
        if num_samples_per_classes is not None:
            samples = list(range(len(score_dataset)))
            random.shuffle(samples)
            samples_len = min(num_samples_per_classes * num_classes, len(score_dataset))
            print("Origin dataset:", len(score_dataset), "Sampled dataset:", samples_len, "Ratio:",
                  float(samples_len) / len(score_dataset))
            # Fix: the Subset was previously bound to the unused local name
            # `dataset`, so subsampling had no effect on the returned dataset.
            score_dataset = Subset(score_dataset, samples[:samples_len])
    return score_dataset, num_classes
def get_transform(resizing='res.'):
    """
    Build the evaluation transform (resize -> tensor -> ImageNet normalization).

    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to 224
        - res.|crop: resize the image such that the smaller side is of size 256 and
            then take a central crop of size 224.
    """
    if resizing == 'default':
        resize = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
        ])
    elif resizing == 'res.':
        resize = T.Resize((224, 224))
    elif resizing == 'res.299':
        resize = T.Resize((299, 299))
    elif resizing == 'res.|crop':
        resize = T.Compose([
            T.Resize((256, 256)),
            T.CenterCrop(224),
        ])
    else:
        raise NotImplementedError(resizing)
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return T.Compose([resize, T.ToTensor(), normalize])
================================================
FILE: examples/semi_supervised_learning/image_classification/README.md
================================================
# Semi-Supervised Learning for Image Classification
## Installation
It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
also need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
The following datasets can be downloaded automatically:
- [FOOD-101](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/)
- [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [SUN397](https://vision.princeton.edu/projects/2010/SUN/)
- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)
- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)
- [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/)
## Supported Methods
Supported methods include:
- [Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (Pseudo Label, ICML 2013)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf)
- [Temporal Ensembling for Semi-Supervised Learning (Pi Model, ICLR 2017)](https://arxiv.org/abs/1610.02242)
- [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (Mean Teacher, NIPS 2017)](https://arxiv.org/abs/1703.01780)
- [Self-Training With Noisy Student Improves ImageNet Classification (Noisy Student, CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf)
- [Unsupervised Data Augmentation for Consistency Training (UDA, NIPS 2020)](https://arxiv.org/pdf/1904.12848v4.pdf)
- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch, NIPS 2020)](https://arxiv.org/abs/2001.07685)
- [Self-Tuning for Data-Efficient Deep Learning (Self-Tuning, ICML 2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf)
- [FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling (FlexMatch, NIPS 2021)](https://arxiv.org/abs/2110.08263)
- [Debiased Learning From Naturally Imbalanced Pseudo-Labels (DebiasMatch, CVPR 2022)](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf)
- [Debiased Self-Training for Semi-Supervised Learning (DST)](https://arxiv.org/abs/2202.07136)
## Usage
### Semi-supervised learning with supervised pre-trained model
The shell files give the script to train with supervised pre-trained model with specified hyper-parameters. For example,
if you want to train UDA on CIFAR100, use the following script
```shell script
# Semi-supervised learning on CIFAR100 (ResNet50, 400labels).
# Assume you have put the datasets under the path `data/cifar100`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class
```
Following common practice in semi-supervised learning, we select a class-balanced subset as the labeled dataset and
treat other samples as unlabeled data. In the above command, `num-samples-per-class` specifies how many labeled samples
for each class. Note that the labeled subset is **deterministic with the same random seed**. Hence, if you want to
compare different algorithms with the same labeled subset, you can simply pass in the same random seed.
### Semi-supervised learning with unsupervised pre-trained model
Take MoCo as an example.
1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco
2. Convert the format of the MoCo checkpoints to the standard format of pytorch
```shell
mkdir checkpoints
python convert_moco_to_pretrained.py checkpoints/moco_v2_800ep_pretrain.pth.tar checkpoints/moco_v2_800ep_backbone.pth checkpoints/moco_v2_800ep_fc.pth
```
3. Start training
```shell
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_4_labels_per_class
```
## Experiment and Results
**Notations**
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with only labeled data.
- ``Oracle`` refers to the model trained using all data as labeled data.
Below are the results of implemented methods. Other than _Oracle_, we randomly sample 4 labels per category.
### ImageNet Supervised Pre-training (ResNet-50)
| Methods | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD | Pets | Flowers | Caltech | Avg |
|--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------|
| ERM | 33.6 | 59.4 | 47.9 | 48.6 | 29.0 | 37.1 | 40.9 | 50.5 | 82.2 | 87.6 | 82.2 | 54.5 |
| Pseudo Label | 36.9 | 62.8 | 52.5 | 54.9 | 30.4 | 40.4 | 41.7 | 54.1 | 89.6 | 93.5 | 85.1 | 58.4 |
| Pi Model | 34.2 | 66.9 | 48.5 | 47.9 | 26.7 | 37.4 | 40.9 | 51.9 | 83.5 | 92.0 | 82.2 | 55.6 |
| Mean Teacher | 40.4 | 78.1 | 58.5 | 52.8 | 32.0 | 45.6 | 40.2 | 53.8 | 86.8 | 92.8 | 83.7 | 60.4 |
| UDA | 41.9 | 73.0 | 59.8 | 55.4 | 33.5 | 42.7 | 42.1 | 49.7 | 88.0 | 93.4 | 85.3 | 60.4 |
| FixMatch | 36.2 | 74.5 | 58.0 | 52.6 | 27.1 | 44.8 | 40.8 | 50.2 | 87.8 | 93.6 | 83.2 | 59.0 |
| Self Tuning | 41.4 | 70.9 | 57.2 | 60.5 | 37.0 | 59.8 | 43.5 | 51.7 | 88.4 | 93.5 | 89.1 | 63.0 |
| FlexMatch | 48.1 | 94.2 | 69.2 | 65.1 | 38.0 | 55.3 | 50.2 | 55.6 | 91.5 | 94.6 | 89.4 | 68.3 |
| DebiasMatch | 57.1 | 92.4 | 69.0 | 66.2 | 41.5 | 65.4 | 48.3 | 54.2 | 90.2 | 95.4 | 89.3 | 69.9 |
| DST | 58.1 | 93.5 | 67.8 | 68.6 | 44.9 | 68.6 | 47.0 | 56.3 | 91.5 | 95.1 | 90.3 | 71.1 |
| Oracle | 85.5 | 97.5 | 86.3 | 81.1 | 85.1 | 91.1 | 64.1 | 68.8 | 93.2 | 98.1 | 92.6 | 85.8 |
### ImageNet Unsupervised Pre-training (ResNet-50, MoCo v2)
| Methods | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD | Pets | Flowers | Caltech | Avg |
|--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------|
| ERM | 33.5 | 63.0 | 50.8 | 39.4 | 28.1 | 40.3 | 40.7 | 53.7 | 65.4 | 87.5 | 82.8 | 53.2 |
| Pseudo Label | 33.6 | 71.9 | 53.8 | 42.7 | 30.9 | 51.2 | 41.2 | 55.2 | 69.3 | 94.2 | 86.2 | 57.3 |
| Pi Model | 32.7 | 77.9 | 50.9 | 33.6 | 27.2 | 34.4 | 41.1 | 54.9 | 66.7 | 91.4 | 84.1 | 54.1 |
| Mean Teacher | 36.8 | 79.0 | 56.7 | 43.0 | 33.0 | 53.9 | 39.5 | 54.5 | 67.8 | 92.7 | 83.3 | 58.2 |
| UDA | 39.5 | 91.3 | 60.0 | 41.9 | 36.2 | 39.7 | 41.7 | 51.5 | 71.0 | 93.7 | 86.5 | 59.4 |
| FixMatch | 44.3 | 86.1 | 58.0 | 42.7 | 38.0 | 55.4 | 42.4 | 53.1 | 67.9 | 95.2 | 83.4 | 60.6 |
| Self Tuning | 34.0 | 63.6 | 51.7 | 43.3 | 32.2 | 50.2 | 40.7 | 52.7 | 68.2 | 91.8 | 87.7 | 56.0 |
| FlexMatch | 50.2 | 96.6 | 69.2 | 49.4 | 41.3 | 62.5 | 47.2 | 54.5 | 72.4 | 94.8 | 89.4 | 66.1 |
| DebiasMatch | 54.2 | 95.5 | 68.1 | 49.1 | 40.9 | 73.0 | 47.6 | 54.4 | 76.6 | 95.5 | 88.7 | 67.6 |
| DST | 57.1 | 95.0 | 68.2 | 53.6 | 47.7 | 72.0 | 46.8 | 56.0 | 76.3 | 95.6 | 90.1 | 68.9 |
| Oracle | 87.0 | 98.2 | 87.9 | 80.6 | 88.7 | 92.7 | 63.9 | 73.8 | 90.6 | 97.8 | 93.1 | 86.8 |
## TODO
1. support multi-gpu training
2. add training from scratch code and results
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{pseudo_label,
title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks},
author={Lee, Dong-Hyun and others},
booktitle={ICML},
year={2013}
}
@inproceedings{pi_model,
title={Temporal ensembling for semi-supervised learning},
author={Laine, Samuli and Aila, Timo},
booktitle={ICLR},
year={2017}
}
@inproceedings{mean_teacher,
title={Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results},
author={Tarvainen, Antti and Valpola, Harri},
booktitle={NIPS},
year={2017}
}
@inproceedings{noisy_student,
title={Self-training with noisy student improves imagenet classification},
author={Xie, Qizhe and Luong, Minh-Thang and Hovy, Eduard and Le, Quoc V},
booktitle={CVPR},
year={2020}
}
@inproceedings{UDA,
title={Unsupervised data augmentation for consistency training},
author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Thang and Le, Quoc},
booktitle={NIPS},
year={2020}
}
@inproceedings{FixMatch,
title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},
author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},
booktitle={NIPS},
year={2020}
}
@inproceedings{SelfTuning,
title={Self-tuning for data-efficient deep learning},
author={Wang, Ximei and Gao, Jinghan and Long, Mingsheng and Wang, Jianmin},
booktitle={ICML},
year={2021}
}
@inproceedings{FlexMatch,
title={Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling},
author={Zhang, Bowen and Wang, Yidong and Hou, Wenxin and Wu, Hao and Wang, Jindong and Okumura, Manabu and Shinozaki, Takahiro},
booktitle={NeurIPS},
year={2021}
}
@inproceedings{DebiasMatch,
title={Debiased Learning from Naturally Imbalanced Pseudo-Labels},
author={Wang, Xudong and Wu, Zhirong and Lian, Long and Yu, Stella X},
booktitle={CVPR},
year={2022}
}
@article{DST,
title={Debiased Self-Training for Semi-Supervised Learning},
author={Chen, Baixu and Jiang, Junguang and Wang, Ximei and Wang, Jianmin and Long, Mingsheng},
journal={arXiv preprint arXiv:2202.07136},
year={2022}
}
```
================================================
FILE: examples/semi_supervised_learning/image_classification/convert_moco_to_pretrained.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import sys
import torch
if __name__ == "__main__":
    # Usage: python convert_moco_to_pretrained.py <moco_ckpt> <backbone_out> <fc_out>
    # Splits a MoCo checkpoint into a backbone state dict and an fc-head state
    # dict in the standard pytorch format.
    input = sys.argv[1]
    obj = torch.load(input, map_location="cpu")
    obj = obj["state_dict"]

    newmodel = {}
    fc = {}
    for k, v in obj.items():
        # Keep only the query-encoder weights; everything else in the
        # checkpoint (momentum encoder, queue buffers, ...) is dropped.
        if not k.startswith("module.encoder_q."):
            continue
        # (removed dead local `old_k = k` — it was never read)
        k = k.replace("module.encoder_q.", "")
        if k.startswith("fc"):
            print(k)
            fc[k] = v
        else:
            newmodel[k] = v

    with open(sys.argv[2], "wb") as f:
        torch.save(newmodel, f)
    with open(sys.argv[3], "wb") as f:
        torch.save(fc, f)
================================================
FILE: examples/semi_supervised_learning/image_classification/debiasmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Set up data, model and optimization for DebiasMatch, then train or test."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic mode trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # Data loading: every training sample yields a (weak, strong) pair of views.
    weak_tf = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                        norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_tf = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                          auto_augment=args.auto_augment,
                                          norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_tf, strong_tf])
    unlabeled_train_transform = MultipleApply([weak_tf, strong_tf])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = utils.get_dataset(
        args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform,
        unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    # Endless iterators so that labeled and unlabeled streams can be drawn in lockstep.
    labeled_train_iter = ForeverDataIterator(
        DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                   num_workers=args.workers, drop_last=True))
    unlabeled_train_iter = ForeverDataIterator(
        DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                   num_workers=args.workers, drop_last=True))
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # optimizer + lr scheduler: either exponential decay or cosine with warmup
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # test-only phase: load the best checkpoint, evaluate once and exit
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # q_hat: running estimate of the model's marginal class distribution, starts uniform
    q_hat = (torch.ones(num_classes) / num_classes).to(device)

    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, q_hat, epoch, args)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # always keep the latest weights; promote to 'best' when top-1 improves
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, q_hat, epoch: int, args: argparse.Namespace):
    """Train `model` for one epoch with DebiasMatch.

    Combines supervised cross entropy on labeled data (weak + weighted strong
    view) with a debiased, confidence-thresholded self-training loss on
    unlabeled data. `q_hat`, an exponential moving average of the model's
    marginal class distribution, is updated *in place* every iteration so that
    the estimate accumulated here remains visible to the caller across epochs.

    Args:
        labeled_train_iter: infinite iterator yielding ((weak, strong), label) labeled batches.
        unlabeled_train_iter: infinite iterator yielding ((weak, strong), label) unlabeled
            batches; the labels are used only to measure pseudo-label accuracy.
        model: classifier mapping images to logits.
        optimizer: SGD optimizer, stepped once per iteration.
        lr_scheduler: learning-rate scheduler, stepped once per iteration.
        q_hat: tensor of shape (num_classes,) holding the running marginal
            class distribution; modified in place.
        epoch: current epoch index (logging only).
        args: parsed command-line arguments.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # supervised cross entropy loss (weak view + weighted strong view)
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss.backward()

        # self-training loss: pseudo labels are derived from the weak view
        # without gradient; gradients flow only through the strong view
        with torch.no_grad():
            y_u = model(x_u)
        y_u_strong = model(x_u_strong)

        # update q_hat IN PLACE. FIX: the previous `q_hat = momentum * q_hat + ...`
        # rebound the local name only, so the accumulated EMA was thrown away when
        # this function returned and q_hat restarted from uniform every epoch.
        with torch.no_grad():
            q = torch.softmax(y_u, dim=1).mean(dim=0)
            q_hat.mul_(args.momentum).add_(q, alpha=1 - args.momentum)

        # debias: shift the strong-view logits towards q_hat and the weak-view
        # (pseudo-label) logits away from it
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong + args.tau * torch.log(q_hat),
                                                                          y_u - args.tau * torch.log(q_hat))
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of unlabeled samples that pass the confidence threshold
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # pseudo-label accuracy: unconfident positions are mapped to -1 so
        # they can never match a true label
        if n_pseudo_labels > 0:
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DebiasMatch for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    # strong augmentation policy used for the strong view in FixMatch-style training
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    # EMA coefficient for the marginal class-distribution estimate q_hat
    parser.add_argument('--momentum', default=0.999, type=float,
                        help='momentum coefficient for updating q_hat (default: 0.999)')
    parser.add_argument('--tau', default=1, type=float,
                        help='debiased strength (default: 1)')
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    # minimum softmax confidence (in (0, 1]) for a pseudo label to be kept
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='confidence threshold')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run (default: 90)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='debiasmatch', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/debiasmatch.sh
================================================
#!/usr/bin/env bash
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.8 --tau 1 --seed 0 --log logs/debiasmatch/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--lr 0.03 --finetune --threshold 0.9 --tau 1 --seed 0 --log logs/debiasmatch/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.7 --tau 1 --seed 0 --log logs/debiasmatch/sun_4_labels_per_class
# ======================================================================================================================
# DTD
# FIX: `--threshold` is a softmax confidence in (0, 1]; the previous value `7`
# could never be reached, so no pseudo label would ever be selected. 0.7 assumed
# (matching the SUN397 supervised setting) -- TODO confirm against the paper's table.
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.7 --tau 3 --seed 0 --log logs/debiasmatch/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# These runs start from a self-supervised MoCo v2 backbone (see
# convert_moco_to_pretrained.py for producing checkpoints/moco_v2_800ep_backbone.pth)
# and use the cosine lr schedule instead of exponential decay.
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/dst.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.self_training.dst import ImageClassifier, WorstCaseEstimationLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Set up data, model and optimization for DST, then train or test."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Deterministic mode trades training speed for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # Data loading: every training sample yields a (weak, strong) pair of views.
    weak_tf = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                        norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_tf = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                          auto_augment=args.auto_augment,
                                          norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_tf, strong_tf])
    unlabeled_train_transform = MultipleApply([weak_tf, strong_tf])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = utils.get_dataset(
        args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform,
        unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    # Endless iterators so labeled and unlabeled streams can be drawn in lockstep.
    labeled_train_iter = ForeverDataIterator(
        DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                   num_workers=args.workers, drop_last=True))
    unlabeled_train_iter = ForeverDataIterator(
        DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                   num_workers=args.workers, drop_last=True))
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model (DST-specific classifier with an extra pseudo head)
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, width=args.width,
                                 pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # optimizer + lr scheduler: either exponential decay or cosine with warmup
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # test-only phase: load the best checkpoint, evaluate once and exit
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # always keep the latest weights; promote to 'best' when top-1 improves
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train the model for one epoch with Debiased Self-Training (DST).

    Each iteration draws one labeled batch and one unlabeled batch, each as a
    ``((weak_view, strong_view), labels)`` pair, and accumulates gradients from
    three objectives before a single ``optimizer.step()``:

    1. cross entropy on the strongly augmented labeled view, scaled by
       ``args.trade_off_cls_strong``;
    2. cross entropy on the weakly augmented labeled view plus the worst-case
       estimation loss on the adversarial head, the latter scaled by ``args.eta``;
    3. a confidence-thresholded self-training loss that fits strong-view
       predictions on unlabeled data to pseudo labels derived from the weak
       view, scaled by ``args.trade_off_self_training``.

    Args:
        labeled_train_iter: infinite iterator over labeled data.
        unlabeled_train_iter: infinite iterator over unlabeled data; its labels
            are used only to measure pseudo-label accuracy, never for training.
        model: classifier whose forward pass returns three outputs
            (main head, adversarial head, and a third head used for the
            self-training target — presumably the pseudo head; see the
            ``--width`` help text below).
        optimizer: SGD optimizer, stepped once per iteration.
        lr_scheduler: scheduler stepped once per iteration.
        epoch: current epoch index (logging only).
        args: parsed command-line hyper-parameters.
    """
    # progress meters for timing, losses and accuracies
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    wce_losses = AverageMeter('Worst Case Estimation Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, wce_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    # criteria are (re)built every epoch; both are stateless wrappers around the thresholds
    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    worst_case_estimation_criterion = WorstCaseEstimationLoss(args.eta_prime).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # ==============================================================================================================
        # cross entropy loss (strong augment)
        # ==============================================================================================================
        y_l_strong, _, _ = model(x_l_strong)
        cls_loss_strong = args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # backward per-loss: gradients from each pass accumulate in .grad and
        # the graph of each forward pass can be freed right away
        cls_loss_strong.backward()

        # single forward pass over labeled + unlabeled weak views, then split
        x = torch.cat((x_l, x_u), dim=0)
        outputs, outputs_adv, _ = model(x)
        y_l, y_u = outputs.chunk(2, dim=0)
        y_l_adv, y_u_adv = outputs_adv.chunk(2, dim=0)
        # ==============================================================================================================
        # cross entropy loss (weak augment)
        # ==============================================================================================================
        cls_loss_weak = F.cross_entropy(y_l, labels_l)
        # ==============================================================================================================
        # worst case estimation loss
        # ==============================================================================================================
        wce_loss = args.eta * worst_case_estimation_criterion(y_l, y_l_adv, y_u, y_u_adv)
        (cls_loss_weak + wce_loss).backward()
        # ==============================================================================================================
        # self training loss
        # ==============================================================================================================
        # pseudo labels come from the weak view (y_u); the strong view is fit to them
        _, _, y_u_strong = model(x_u_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u)
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        cls_loss = cls_loss_strong + cls_loss_weak
        cls_losses.update(cls_loss.item(), batch_size)
        loss = cls_loss + self_training_loss + wce_loss
        losses.update(loss.item(), batch_size)
        wce_losses.update(wce_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)

        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of pseudo labels
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
            # entries not selected by the confidence mask become -1, which can
            # never equal a real class label, so only confident pseudo labels
            # are compared against the ground truth
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step (applies the accumulated gradients
        # of all three backward passes above)
        optimizer.step()
        lr_scheduler.step()
        # project-specific per-iteration hook on the classifier
        # NOTE(review): presumably updates/synchronizes the auxiliary heads — confirm in dst.py's ImageClassifier
        model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Debiased Self-Training for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--width', default=2048, type=int,
                        help='width of the pseudo head and the worst-case estimation head')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters (see train() above for how the trade-off weights are applied)
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--eta', default=1, type=float,
                        help='the trade-off hyper-parameter of adversarial loss')
    parser.add_argument('--eta-prime', default=2, type=float,
                        help="the trade-off hyper-parameter between adversarial loss on labeled data "
                             "and that on unlabeled data")
    parser.add_argument('--threshold', default=0.7, type=float,
                        help='confidence threshold')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    # lr-gamma and lr-decay are only consumed by the 'exp' scheduler
    # (lr * (1 + lr_gamma * step) ** (-lr_decay))
    parser.add_argument('--lr-gamma', default=0.0002, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run (default: 90)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='dst', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/dst.sh
================================================
#!/usr/bin/env bash
# Reference commands for the Debiased Self-Training (DST) benchmarks.
# Each command trains dst.py on one dataset with 4 labeled samples per class.
# First section: ImageNet supervised ResNet50 backbone (default).
# Second section: MoCo v2 unsupervised ResNet50 backbone (requires the
# checkpoint at checkpoints/moco_v2_800ep_backbone.pth) with a cosine schedule.

# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.8 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.95 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 1 \
  --seed 0 --log logs/dst/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 4 \
  --seed 0 --log logs/dst/caltech_4_labels_per_class

# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 3 \
  --seed 0 --log logs/dst_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/erm.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, ConcatDataset
import utils
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train (or test) the supervised ERM baseline for semi-supervised learning.

    Builds the labeled training set (or, with ``--oracle``, the union of
    labeled and unlabeled data treated as fully labeled), fits a classifier
    with SGD under an exponential or cosine-with-warmup schedule, validates
    after every epoch, and keeps the checkpoint with the best top-1 accuracy.
    When ``args.phase == 'test'`` only the best checkpoint is evaluated.

    Args:
        args: parsed command-line arguments (see the parser at the bottom of
            this file).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # Data loading code: every training sample yields a (weak, strong) pair
    # of views via MultipleApply, mirroring the other SSL scripts
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('train_transform: ', train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, train_transform,
                          val_transform,
                          seed=args.seed)
    if args.oracle:
        # oracle upper bound: treat all samples (labeled + unlabeled) as labeled
        num_classes = labeled_train_dataset.num_classes
        labeled_train_dataset = ConcatDataset([labeled_train_dataset, unlabeled_train_dataset])
        # ConcatDataset has no num_classes attribute; re-attach it for the code below
        labeled_train_dataset.num_classes = num_classes
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr; get_last_lr() is the supported accessor — calling
        # get_lr() outside of scheduler.step() is deprecated and warns
        print(lr_scheduler.get_last_lr())
        # train for one epoch
        utils.empirical_risk_minimization(labeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args, device)
        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Baseline for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    parser.add_argument('--oracle', action='store_true', default=False,
                        help='use all data as labeled data (oracle)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    # NOTE(review): --trade-off-cls-strong is not referenced in main() above;
    # presumably consumed inside utils.empirical_risk_minimization — confirm
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    # lr-gamma and lr-decay are only consumed by the 'exp' scheduler
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run (default: 20)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='baseline', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/erm.sh
================================================
#!/usr/bin/env bash
# Reference commands for the supervised ERM baseline (erm.py).
# For each dataset, three settings: 4 labels per class, 10 labels per class,
# and the fully-labeled oracle upper bound (--oracle, trained for 80 epochs).

# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/erm/food101_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 10 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/erm/food101_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --oracle -a resnet50 \
  --lr 0.01 --finetune --epochs 80 --seed 0 --log logs/erm/food101_oracle
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/cifar10_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 10 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/cifar10_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --oracle -a resnet50 \
  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/cifar10_oracle
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/erm/cifar100_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 10 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/erm/cifar100_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --oracle -a resnet50 \
  --lr 0.01 --finetune --epochs 80 --seed 0 --log logs/erm/cifar100_oracle
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/erm/cub200_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 10 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/erm/cub200_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --oracle -a resnet50 \
  --lr 0.003 --finetune --epochs 80 --seed 0 --log logs/erm/cub200_oracle
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/aircraft_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 10 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/aircraft_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --oracle -a resnet50 \
  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/aircraft_oracle
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/car_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 10 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/car_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --oracle -a resnet50 \
  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/car_oracle
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/erm/sun_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 10 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/erm/sun_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --oracle -a resnet50 \
  --lr 0.001 --finetune --epochs 80 --seed 0 --log logs/erm/sun_oracle
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/dtd_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 10 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/dtd_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --oracle -a resnet50 \
  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/dtd_oracle
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/erm/pets_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 10 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/erm/pets_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --oracle -a resnet50 \
  --lr 0.001 --finetune --epochs 80 --seed 0 --log logs/erm/pets_oracle
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/flowers_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 10 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/erm/flowers_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --oracle -a resnet50 \
  --lr 0.03 --finetune --epochs 80 --seed 0 --log logs/erm/flowers_oracle
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/erm/caltech_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 10 -a resnet50 \
--lr 0.003 --finetune --seed 0 --log logs/erm/caltech_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --oracle -a resnet50 \
--lr 0.003 --finetune --epochs 80 --seed 0 --log logs/erm/caltech_oracle
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/food101_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/food101_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/food101_oracle
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar10_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar10_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cifar10_oracle
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cifar100_oracle
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cub200_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cub200_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/cub200_oracle
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/aircraft_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/aircraft_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/aircraft_oracle
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/car_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/car_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/cars -d StanfordCars --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/car_oracle
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/sun_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/sun_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/sun_oracle
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/dtd_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/dtd_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/dtd_oracle
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/pets_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/pets_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/pets -d OxfordIIITPets --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/pets_oracle
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/flowers_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/flowers_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/flowers -d OxfordFlowers102 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/flowers_oracle
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/caltech_4_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --num-samples-per-class 10 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/caltech_10_labels_per_class
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --oracle -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --epochs 80 --seed 0 --log logs/erm_moco_pretrain/caltech_oracle
================================================
FILE: examples/semi_supervised_learning/image_classification/fixmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
logger = CompleteLogger(args.log, args.phase)
print(args)
if args.seed is not None:
random.seed(args.seed)
torch.manual_seed(args.seed)
cudnn.deterministic = True
warnings.warn('You have chosen to seed training. '
'This will turn on the CUDNN deterministic setting, '
'which can slow down your training considerably! '
'You may see unexpected behavior when restarting '
'from checkpoints.')
cudnn.benchmark = True
# Data loading code
weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
norm_mean=args.norm_mean, norm_std=args.norm_std)
strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
auto_augment=args.auto_augment,
norm_mean=args.norm_mean, norm_std=args.norm_std)
labeled_train_transform = MultipleApply([weak_augment, strong_augment])
unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
print('labeled_train_transform: ', labeled_train_transform)
print('unlabeled_train_transform: ', unlabeled_train_transform)
print('val_transform:', val_transform)
labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
utils.get_dataset(args.data,
args.num_samples_per_class,
args.root, labeled_train_transform,
val_transform,
unlabeled_train_transform=unlabeled_train_transform,
seed=args.seed)
print("labeled_dataset_size: ", len(labeled_train_dataset))
print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
print("val_dataset_size: ", len(val_dataset))
labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, drop_last=True)
unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
num_workers=args.workers, drop_last=True)
labeled_train_iter = ForeverDataIterator(labeled_train_loader)
unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
# create model
print("=> using pre-trained model '{}'".format(args.arch))
backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
num_classes = labeled_train_dataset.num_classes
pool_layer = nn.Identity() if args.no_pool else None
classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
finetune=args.finetune).to(device)
print(classifier)
# define optimizer and lr scheduler
if args.lr_scheduler == 'exp':
optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
else:
optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
nesterov=True)
lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)
# resume from the best checkpoint
if args.phase == 'test':
checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
classifier.load_state_dict(checkpoint)
acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
print(acc1)
return
# start training
best_acc1 = 0.0
best_avg = 0.0
for epoch in range(args.epochs):
# print lr
print(lr_scheduler.get_lr())
# train for one epoch
train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)
# evaluate on validation set
acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
# remember best acc@1 and save checkpoint
torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
if acc1 > best_acc1:
shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
best_acc1 = max(acc1, best_acc1)
best_avg = max(avg, best_avg)
print("best_acc1 = {:3.1f}".format(best_acc1))
print('best_avg = {:3.1f}'.format(best_avg))
logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
batch_time = AverageMeter('Time', ':2.2f')
data_time = AverageMeter('Data', ':2.1f')
cls_losses = AverageMeter('Cls Loss', ':3.2f')
self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
losses = AverageMeter('Loss', ':3.2f')
cls_accs = AverageMeter('Cls Acc', ':3.1f')
pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
progress = ProgressMeter(
args.iters_per_epoch,
[batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
pseudo_label_ratios],
prefix="Epoch: [{}]".format(epoch))
self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
# switch to train mode
model.train()
end = time.time()
batch_size = args.batch_size
for i in range(args.iters_per_epoch):
(x_l, x_l_strong), labels_l = next(labeled_train_iter)
x_l = x_l.to(device)
x_l_strong = x_l_strong.to(device)
labels_l = labels_l.to(device)
(x_u, x_u_strong), labels_u = next(unlabeled_train_iter)
x_u = x_u.to(device)
x_u_strong = x_u_strong.to(device)
labels_u = labels_u.to(device)
# measure data loading time
data_time.update(time.time() - end)
# clear grad
optimizer.zero_grad()
# compute output
# cross entropy loss
y_l = model(x_l)
y_l_strong = model(x_l_strong)
cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
cls_loss.backward()
# self training loss
with torch.no_grad():
y_u = model(x_u)
y_u_strong = model(x_u_strong)
self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u)
self_training_loss = args.trade_off_self_training * self_training_loss
self_training_loss.backward()
# measure accuracy and record loss
loss = cls_loss + self_training_loss
losses.update(loss.item(), batch_size)
cls_losses.update(cls_loss.item(), batch_size)
self_training_losses.update(self_training_loss.item(), batch_size)
cls_acc = accuracy(y_l, labels_l)[0]
cls_accs.update(cls_acc.item(), batch_size)
# ratio of pseudo labels
n_pseudo_labels = mask.sum()
ratio = n_pseudo_labels / batch_size
pseudo_label_ratios.update(ratio.item() * 100, batch_size)
# accuracy of pseudo labels
if n_pseudo_labels > 0:
pseudo_labels = pseudo_labels * mask - (1 - mask)
n_correct = (pseudo_labels == labels_u).float().sum()
pseudo_label_acc = n_correct / n_pseudo_labels * 100
pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)
# compute gradient and do SGD step
optimizer.step()
lr_scheduler.step()
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if i % args.print_freq == 0:
progress.display(i)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='FixMatch for Semi Supervised Learning')
# dataset parameters
parser.add_argument('root', metavar='DIR',
help='root path of dataset')
parser.add_argument('-d', '--data', metavar='DATA',
help='dataset: ' + ' | '.join(utils.get_dataset_names()))
parser.add_argument('--num-samples-per-class', default=4, type=int,
help='number of labeled samples per class')
parser.add_argument('--train-resizing', default='default', type=str)
parser.add_argument('--val-resizing', default='default', type=str)
parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
help='normalization mean')
parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
help='normalization std')
parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
help='AutoAugment policy (default: rand-m10-n2-mstd2)')
# model parameters
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
parser.add_argument('--bottleneck-dim', default=1024, type=int,
help='dimension of bottleneck')
parser.add_argument('--no-pool', action='store_true', default=False,
help='no pool layer after the feature extractor')
parser.add_argument('--pretrained-backbone', default=None, type=str,
help="pretrained checkpoint of the backbone "
"(default: None, use the ImageNet supervised pretrained backbone)")
parser.add_argument('--finetune', action='store_true', default=False,
help='whether to use 10x smaller lr for backbone')
# training parameters
parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
parser.add_argument('--trade-off-self-training', default=1, type=float,
help='the trade-off hyper-parameter of self training loss')
parser.add_argument('--threshold', default=0.95, type=float,
help='confidence threshold')
parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
help='mini-batch size (default: 32)')
parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
help='initial learning rate')
parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
help='learning rate decay strategy')
parser.add_argument('--lr-gamma', default=0.0004, type=float,
help='parameter for lr scheduler')
parser.add_argument('--lr-decay', default=0.75, type=float,
help='parameter for lr scheduler')
parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
help='weight decay (default:5e-4)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=60, type=int, metavar='N',
help='number of total epochs to run (default: 60)')
parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
help='number of iterations per epoch (default: 500)')
parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
help='print frequency (default: 100)')
parser.add_argument('--seed', default=None, type=int,
help='seed for initializing training ')
parser.add_argument("--log", default='fixmatch', type=str,
help="where to save logs, checkpoints and debugging images")
parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
help="when phase is 'test', only test the model")
args = parser.parse_args()
main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/fixmatch.sh
================================================
#!/usr/bin/env bash
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/fixmatch/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/fixmatch/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.9 --seed 0 --log logs/fixmatch/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/flexmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.flexmatch import DynamicThresholdingModule
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train (or, with ``--phase test``, only evaluate) a FlexMatch image classifier.

    Builds weak/strong augmentation pipelines and the labeled/unlabeled/val data
    loaders, creates the classifier plus optimizer and lr scheduler, then runs
    ``args.epochs`` epochs of FlexMatch training, checkpointing the latest and
    best models under the log directory.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: both labeled and unlabeled samples yield a
    # (weak view, strong view) pair of augmentations.
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    # Wrap the unlabeled set so that each batch also carries sample indices;
    # train() feeds those indices to the dynamic thresholding module to track
    # per-sample learning status.
    unlabeled_train_dataset = utils.convert_dataset(unlabeled_train_dataset)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    # ForeverDataIterator restarts an exhausted loader, so the training loop can
    # draw a fixed number of batches per epoch regardless of dataset size.
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    # Stash num_classes on args; presumably read downstream (e.g. by utils.validate) — verify in utils.
    args.num_classes = num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint (evaluation-only mode)
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # thresholding module with convex mapping function x / (2 - x)
    thresholding_module = DynamicThresholdingModule(args.threshold, args.warmup, lambda x: x / (2 - x), num_classes,
                                                   len(unlabeled_train_dataset), device=device)

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr (NOTE(review): get_lr() is deprecated in newer torch in favor of get_last_lr())
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, thresholding_module, classifier, optimizer, lr_scheduler, epoch,
              args)
        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator,
          thresholding_module: DynamicThresholdingModule, model, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of FlexMatch training.

    Each iteration combines:
      * a supervised cross-entropy loss on the labeled batch (weak view, plus a
        strong view weighted by ``args.trade_off_cls_strong``), and
      * a self-training loss on the unlabeled batch: pseudo labels come from the
        weakly-augmented view (computed under ``no_grad``) and supervise the
        strongly-augmented view, masked by per-class dynamic thresholds.

    Args:
        labeled_train_iter: infinite iterator yielding ((weak, strong), labels).
        unlabeled_train_iter: infinite iterator yielding
            (sample_indices, ((weak, strong), labels)); the ground-truth labels
            are used only to report pseudo-label accuracy, never for training.
        thresholding_module: maintains per-class dynamic thresholds and the
            per-sample learning status (updated each iteration).
        model: classifier being trained.
        optimizer: SGD optimizer; stepped once per iteration.
        lr_scheduler: stepped once per iteration (per-step schedule).
        epoch: current epoch index (used for log prefixes).
        args: parsed command-line arguments.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)
        # idx_u: dataset indices of the unlabeled samples (added by utils.convert_dataset)
        idx_u, ((x_u, x_u_strong), labels_u) = next(unlabeled_train_iter)
        idx_u = idx_u.to(device)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss on labeled data; backward immediately so the two
        # losses accumulate gradients before the single optimizer.step() below
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss.backward()

        # self training loss: pseudo labels come from the weak view, computed
        # under no_grad so gradients only flow through the strong view
        with torch.no_grad():
            y_u = model(x_u)
        y_u_strong = model(x_u_strong)
        confidence, pseudo_labels = torch.softmax(y_u, dim=1).max(dim=1)
        # keep only predictions above the per-class dynamic threshold
        dynamic_threshold = thresholding_module.get_threshold(pseudo_labels)
        mask = (confidence > dynamic_threshold).float()
        # mask used for updating learning status (fixed global threshold)
        selected_mask = (confidence > args.threshold).long()
        thresholding_module.update(idx_u, selected_mask, pseudo_labels)
        self_training_loss = args.trade_off_self_training * (
                F.cross_entropy(y_u_strong, pseudo_labels, reduction='none') * mask).mean()
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of pseudo labels that passed the dynamic threshold
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # accuracy of pseudo labels (diagnostic only — uses ground-truth labels_u)
        if n_pseudo_labels > 0:
            # set unselected entries to -1 so they can never match a real label
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step (applies both accumulated losses)
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    def _str2bool(v) -> bool:
        """Parse a boolean command-line value.

        ``type=bool`` is a classic argparse pitfall: every non-empty string
        (including ``'False'``) is truthy, so ``--warmup False`` would silently
        enable warmup.  This converter accepts the usual spellings of
        true/false and rejects anything else.
        """
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got {!r}'.format(v))

    parser = argparse.ArgumentParser(description='FlexMatch for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    # FIX: was `type=bool`, which parsed any non-empty string (even 'False') as True
    parser.add_argument('--warmup', default=False, type=_str2bool,
                        help='whether to apply warmup in the dynamic thresholding module (true/false)')
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='confidence threshold')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run (default: 90)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='flexmatch', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/flexmatch.sh
================================================
#!/usr/bin/env bash
# Training recipes for flexmatch.py with 4 labeled samples per class.
# Hyper-parameters (lr, confidence threshold) are tuned per dataset.
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# NOTE: these runs load the MoCo v2 checkpoint from checkpoints/moco_v2_800ep_backbone.pth
# and switch to the cosine lr scheduler (--lr-scheduler cos).
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/flexmatch_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/flexmatch_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
--lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/mean_teacher.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss
from tllib.self_training.mean_teacher import update_bn, EMATeacher
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train (or, with ``--phase test``, only evaluate) a Mean Teacher classifier.

    Builds the augmentation pipelines and data loaders, creates the student
    classifier and its EMA teacher, then runs ``args.epochs`` epochs of
    training, checkpointing the latest and best student models under the log
    directory.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: labeled samples yield a (weak, strong) view pair;
    # unlabeled samples yield two independent weak views (one for the student,
    # one for the teacher — see train()).
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, weak_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    # ForeverDataIterator restarts an exhausted loader, so the training loop can
    # draw a fixed number of batches per epoch regardless of dataset size.
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model: the teacher holds an exponential moving average of the student
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    teacher = EMATeacher(classifier, alpha=args.alpha)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint (evaluation-only mode)
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr (NOTE(review): get_lr() is deprecated in newer torch in favor of get_last_lr())
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, teacher, optimizer, lr_scheduler, epoch, args)
        # evaluate on validation set (the student model is evaluated and checkpointed)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of Mean Teacher training.

    Each iteration combines:
      * a supervised cross-entropy loss on the labeled batch (weak view, plus a
        strong view weighted by ``args.trade_off_cls_strong``), and
      * an L2 consistency loss between the student's prediction on one weak
        view of an unlabeled image and the EMA teacher's prediction on a second
        weak view, ramped up over ``args.warm_up_epochs`` via ``sigmoid_warm_up``.

    After every optimizer step the teacher's EMA parameters are updated.

    Args:
        labeled_train_iter: infinite iterator yielding ((weak, strong), labels).
        unlabeled_train_iter: infinite iterator yielding ((weak_1, weak_2), _);
            labels are ignored.
        model: the student classifier being optimized.
        teacher: EMA teacher wrapping the student.
        optimizer: SGD optimizer; stepped once per iteration.
        lr_scheduler: stepped once per iteration (per-step schedule).
        epoch: current epoch index (used for log prefixes and the warm-up ramp).
        args: parsed command-line arguments.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    con_losses = AverageMeter('Con Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    consistency_criterion = L2ConsistencyLoss(reduction='sum').to(device)

    # switch to train mode
    model.train()
    teacher.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)
        # two independent weak views of the same unlabeled images
        (x_u, x_u_teacher), _ = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_teacher = x_u_teacher.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss on labeled data; backward immediately so the two
        # losses accumulate gradients before the single optimizer.step() below
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss.backward()

        # consistency loss: teacher targets are computed under no_grad, so
        # gradients only flow through the student's prediction p_u
        with torch.no_grad():
            y_u_teacher = teacher(x_u_teacher)
            p_u_teacher = torch.softmax(y_u_teacher, dim=1)
        y_u = model(x_u)
        p_u = torch.softmax(y_u, dim=1)
        # ramped up from ~0 to args.trade_off_con over the warm-up epochs
        con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) * \
                   consistency_criterion(p_u, p_u_teacher)
        con_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + con_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        con_losses.update(con_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step (applies both accumulated losses)
        optimizer.step()
        lr_scheduler.step()

        # update teacher: ramp the EMA decay from 0 toward args.alpha so early
        # steps track the (rapidly changing) student more closely
        global_step = epoch * args.iters_per_epoch + i + 1
        teacher.set_alpha(min(args.alpha, 1 - 1 / global_step))
        teacher.update()
        # NOTE(review): update_bn appears to sync BatchNorm buffers between the
        # student and the EMA teacher — confirm direction in
        # tllib.self_training.mean_teacher.update_bn
        update_bn(model, teacher.teacher)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface for the Mean Teacher trainer.
    parser = argparse.ArgumentParser(description='Mean Teacher for Semi Supervised Learning')
    add = parser.add_argument

    # --- dataset parameters ---
    add('root', help='root path of dataset', metavar='DIR')
    add('-d', '--data', help='dataset: ' + ' | '.join(utils.get_dataset_names()), metavar='DATA')
    add('--num-samples-per-class', type=int, default=4,
        help='number of labeled samples per class')
    add('--train-resizing', type=str, default='default')
    add('--val-resizing', type=str, default='default')
    add('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406),
        help='normalization mean')
    add('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225),
        help='normalization std')
    add('--auto-augment', type=str, default='rand-m10-n2-mstd2',
        help='AutoAugment policy (default: rand-m10-n2-mstd2)')

    # --- model parameters ---
    add('-a', '--arch', metavar='ARCH', choices=utils.get_model_names(), default='resnet50',
        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    add('--bottleneck-dim', type=int, default=1024,
        help='dimension of bottleneck')
    add('--no-pool', default=False, action='store_true',
        help='no pool layer after the feature extractor')
    add('--pretrained-backbone', type=str, default=None,
        help="pretrained checkpoint of the backbone "
             "(default: None, use the ImageNet supervised pretrained backbone)")
    add('--finetune', default=False, action='store_true',
        help='whether to use 10x smaller lr for backbone')
    add('--alpha', type=float, default=0.999,
        help='ema decay factor')

    # --- training parameters ---
    add('--trade-off-cls-strong', type=float, default=0.1,
        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    add('--trade-off-con', type=float, default=0.1,
        help='the trade-off hyper-parameter of consistency loss')
    add('-b', '--batch-size', metavar='N', type=int, default=32,
        help='mini-batch size (default: 32)')
    add('--lr', '--learning-rate', metavar='LR', dest='lr', type=float, default=0.01,
        help='initial learning rate')
    add('--lr-scheduler', choices=['exp', 'cos'], type=str, default='exp',
        help='learning rate decay strategy')
    add('--lr-gamma', type=float, default=0.0001,
        help='parameter for lr scheduler')
    add('--lr-decay', type=float, default=0.75,
        help='parameter for lr scheduler')
    add('--wd', '--weight-decay', metavar='W', type=float, default=5e-4,
        help='weight decay (default:5e-4)')
    add('-j', '--workers', metavar='N', type=int, default=4,
        help='number of data loading workers (default: 4)')
    add('--epochs', metavar='N', type=int, default=40,
        help='number of total epochs to run (default: 40)')
    add('--warm-up-epochs', type=int, default=10,
        help='number of epochs to warm up (default: 10)')
    add('-i', '--iters-per-epoch', type=int, default=500,
        help='number of iterations per epoch (default: 500)')
    add('-p', '--print-freq', metavar='N', type=int, default=100,
        help='print frequency (default: 100)')
    add('--seed', type=int, default=None,
        help='seed for initializing training ')
    add('--log', type=str, default='mean_teacher',
        help='where to save logs, checkpoints and debugging images')
    add('--phase', type=str, default='train', choices=['train', 'test'],
        help="when phase is 'test', only test the model")

    main(parser.parse_args())
================================================
FILE: examples/semi_supervised_learning/image_classification/mean_teacher.sh
================================================
#!/usr/bin/env bash
# Mean Teacher experiments with 4 labeled samples per class on common image
# classification benchmarks. The first section uses the ImageNet supervised
# pretrained ResNet50; the second uses the MoCo v2 unsupervised pretrained
# backbone (expects checkpoints/moco_v2_800ep_backbone.pth, and switches to
# the cosine lr schedule). Per-dataset learning rates are the tuned defaults.
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/mean_teacher/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/noisy_student.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import copy
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.vision.models.reid.loss import CrossEntropyLoss
from tllib.modules.classifier import Classifier
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class ImageClassifier(Classifier):
    """Classifier variant for Noisy Student training.

    A bottleneck (Linear -> BatchNorm1d -> ReLU) is inserted between the backbone and the
    head, and a Dropout(0.5) layer noises the bottleneck features while training the
    student. Setting ``as_teacher_model = True`` bypasses the dropout, so the model can
    produce un-noised predictions when it serves as the teacher.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        projection = nn.Linear(backbone.out_features, bottleneck_dim)
        # same initialization scheme as the other classifiers in this repository
        projection.weight.data.normal_(0, 0.005)
        projection.bias.data.fill_(0.1)
        bottleneck = nn.Sequential(
            projection,
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        self.dropout = nn.Dropout(0.5)
        # toggled externally; when True the extra dropout is skipped in forward()
        self.as_teacher_model = False

    def forward(self, x: torch.Tensor):
        """Return class logits for a batch of images ``x``."""
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        if not self.as_teacher_model:
            features = self.dropout(features)
        return self.head(features)
def calc_teacher_output(classifier_teacher: ImageClassifier, weak_augmented_unlabeled_dataset):
    """Compute outputs of the teacher network. Here, we use weak data augmentation and do not introduce an additional
    dropout layer according to the Noisy Student paper `Self-Training With Noisy Student Improves ImageNet
    Classification <https://arxiv.org/abs/1911.04252>`_.

    Returns a single tensor of teacher logits, one row per sample of the unlabeled dataset, in
    dataset order (the loader neither shuffles nor drops the last partial batch), so rows can
    later be looked up by sample index during student training.

    NOTE(review): this function reads the module-level ``args`` (batch_size, workers,
    print_freq), which only exists when the script runs as ``__main__`` — confirm before
    importing this module elsewhere.
    """
    # shuffle=False keeps output rows aligned with dataset indices; drop_last=False covers every sample
    data_loader = DataLoader(weak_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.workers, drop_last=False)
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(data_loader),
        [batch_time],
        prefix='Computing teacher output: ')
    teacher_output = []
    # inference only: no gradients are needed for the (frozen) teacher
    with torch.no_grad():
        end = time.time()
        for i, (images, _) in enumerate(data_loader):
            images = images.to(device)
            output = classifier_teacher(images)
            teacher_output.append(output)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                progress.display(i)
    # concatenate per-batch logits into one (num_samples, num_classes) tensor
    teacher_output = torch.cat(teacher_output, dim=0)
    return teacher_output
def main(args: argparse.Namespace):
    """Train (or, when ``args.phase == 'test'``, only evaluate) a Noisy Student classifier.

    If ``args.pretrained_teacher`` is given, the teacher's predictions on the weakly augmented
    unlabeled set are computed once up front and the student is trained with ``train``;
    otherwise the model is trained on labeled data only (plain ERM), producing the initial
    teacher for the next self-training round.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # deterministic cuDNN trades speed for reproducibility
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    # each labeled sample is seen under both a weak and a strong view
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('weak_augment (input transform for teacher model): ', weak_augment)
    print('strong_augment (input transform for student model): ', strong_augment)
    print('val_transform:', val_transform)

    # two dataset instances that differ only in the unlabeled transform: the weak view feeds
    # the teacher, the strong view feeds the student (same seed => same labeled/unlabeled split)
    labeled_train_dataset, weak_augmented_unlabeled_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=weak_augment,
                          seed=args.seed)
    _, strong_augmented_unlabeled_dataset, _ = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=strong_augment,
                          seed=args.seed)
    # after conversion the unlabeled loader yields (index, (image, target)) — see train(),
    # where the index selects the matching row of the precomputed teacher output
    strong_augmented_unlabeled_dataset = utils.convert_dataset(strong_augmented_unlabeled_dataset)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(weak_augmented_unlabeled_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(strong_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                 finetune=args.finetune).to(device)
    print(classifier)

    if args.pretrained_teacher:
        # load teacher model and cache its predictions on all unlabeled data once up front
        classifier_teacher = copy.deepcopy(classifier)
        checkpoint = torch.load(args.pretrained_teacher)
        classifier_teacher.load_state_dict(checkpoint)
        classifier_teacher.eval()
        # disable the extra dropout: the teacher's predictions should not be noised
        classifier_teacher.as_teacher_model = True
        print('compute outputs of the teacher network')
        teacher_output = calc_teacher_output(classifier_teacher, weak_augmented_unlabeled_dataset)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())

        # train for one epoch
        if args.pretrained_teacher:
            train(labeled_train_iter, unlabeled_train_iter, classifier, teacher_output, optimizer, lr_scheduler,
                  epoch, args)
        else:
            # no teacher yet: plain supervised training on the labeled subset
            utils.empirical_risk_minimization(labeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args,
                                              device)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher_output,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of student training against fixed teacher predictions.

    Each iteration combines a supervised cross-entropy term (weak view plus a
    down-weighted strong view of the labeled batch) with a distillation term between
    the student's prediction on a strongly augmented unlabeled batch and the cached
    teacher logits for the same samples, both softened by temperature ``args.T``.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    self_training_criterion = CrossEntropyLoss().to(device)

    # switch to train mode
    model.train()

    n = args.batch_size
    tic = time.time()
    for step in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        # the unlabeled loader also yields sample indices, used to address teacher_output
        idx_u, (x_u_strong, _) = next(unlabeled_train_iter)
        idx_u = idx_u.to(device)
        x_u_strong = x_u_strong.to(device)

        # measure data loading time
        data_time.update(time.time() - tic)

        optimizer.zero_grad()

        # supervised term on both views of the labeled batch
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # gradients from the two loss terms are accumulated by separate backward passes
        cls_loss.backward()

        # distillation term against the precomputed teacher logits
        y_u = teacher_output[idx_u]
        y_u_strong = model(x_u_strong)
        self_training_loss = args.trade_off_self_training * self_training_criterion(y_u_strong / args.T, y_u / args.T)
        self_training_loss.backward()

        # bookkeeping: record losses and labeled-batch accuracy
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), n)
        cls_losses.update(cls_loss.item(), n)
        self_training_losses.update(self_training_loss.item(), n)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), n)

        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for Noisy Student; parsed args are forwarded to main().
    parser = argparse.ArgumentParser(description='Noisy Student for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # when omitted, main() falls back to supervised-only training (the first teacher)
    parser.add_argument('--pretrained-teacher', default=None, type=str,
                        help='pretrained checkpoint of the teacher model')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--T', default=2, type=float,
                        help='temperature')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run (default: 40)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='noisy_student', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/noisy_student.sh
================================================
#!/usr/bin/env bash
# Noisy Student iterative self-training on CIFAR-100 (4 labeled samples per class).
# Round 0 trains an initial teacher on labeled data only (no --pretrained-teacher);
# each subsequent round distills the previous round's latest checkpoint into a new
# student, whose checkpoint then serves as the next round's teacher.
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --epochs 20 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_0
for round in 0 1 2; do
  CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
    --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
    --pretrained-teacher logs/noisy_student/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \
    --lr 0.01 --finetune --epochs 40 --T 0.5 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_$((round + 1))
done
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --epochs 20 --seed 0 \
  --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_0
for round in 0 1 2; do
  CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
    --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
    --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
    --pretrained-teacher logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \
    --lr 0.001 --finetune --lr-scheduler cos --epochs 40 --T 1 --seed 0 \
    --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$((round + 1))
done
================================================
FILE: examples/semi_supervised_learning/image_classification/pi_model.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Train (or, when ``args.phase == 'test'``, only evaluate) a Pi-Model classifier.

    Labeled data is seen under a weak and a strong view for the supervised loss; each
    unlabeled image is seen under two independent weak views, whose predictions are
    pushed together by a consistency loss in ``train``.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # deterministic cuDNN trades speed for reproducibility
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    # two stochastic draws of the same weak augmentation => two different views per image
    unlabeled_train_transform = MultipleApply([weak_augment, weak_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train *model* for one epoch with the Pi-Model objective.

    Every iteration combines a supervised cross-entropy term (computed on
    both the weakly and the strongly augmented labeled views) with an L2
    consistency term between the model's predictions on two stochastic
    views of the same unlabeled batch.  The consistency term is scaled by
    ``sigmoid_warm_up(epoch, args.warm_up_epochs)`` so that it ramps up
    over the first epochs.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    con_losses = AverageMeter('Con Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    consistency_criterion = L2ConsistencyLoss().to(device)

    # switch to train mode
    model.train()

    tic = time.time()
    n = args.batch_size
    for step in range(args.iters_per_epoch):
        # labeled batch carries two augmented views; unlabeled batch carries
        # two stochastic views whose labels are discarded
        (img_weak, img_strong), targets = next(labeled_train_iter)
        img_weak = img_weak.to(device)
        img_strong = img_strong.to(device)
        targets = targets.to(device)
        (u_view1, u_view2), _ = next(unlabeled_train_iter)
        u_view1 = u_view1.to(device)
        u_view2 = u_view2.to(device)
        # measure data loading time
        data_time.update(time.time() - tic)

        optimizer.zero_grad()

        # supervised cross-entropy on both labeled views; backward right away
        # so gradients accumulate before the unlabeled forward passes
        logits_weak = model(img_weak)
        logits_strong = model(img_strong)
        cls_loss = F.cross_entropy(logits_weak, targets) \
            + args.trade_off_cls_strong * F.cross_entropy(logits_strong, targets)
        cls_loss.backward()

        # consistency between softmax predictions on the two unlabeled views,
        # weighted by the warm-up schedule
        prob1 = torch.softmax(model(u_view1), dim=1)
        prob2 = torch.softmax(model(u_view2), dim=1)
        con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) \
            * consistency_criterion(prob1, prob2)
        con_loss.backward()

        # bookkeeping: losses and labeled-batch accuracy
        total_loss = cls_loss + con_loss
        losses.update(total_loss.item(), n)
        cls_losses.update(cls_loss.item(), n)
        con_losses.update(con_loss.item(), n)
        cls_accs.update(accuracy(logits_weak, targets)[0].item(), n)

        # apply the accumulated gradients, then advance the lr schedule
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface; pi_model.sh contains the per-dataset invocations.
    parser = argparse.ArgumentParser(description='Pi Model for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    # strong augmentation policy applied to labeled images in addition to the weak view
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-con', default=0.1, type=float,
                        help='the trade-off hyper-parameter of consistency loss')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    # --lr-gamma and --lr-decay are only used by the 'exp' scheduler (see main)
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run (default: 40)')
    # length of the sigmoid ramp-up applied to the consistency loss in train()
    parser.add_argument('--warm-up-epochs', default=10, type=int,
                        help='number of epochs to warm up (default: 10)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='pi_model', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/pi_model.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for the Pi-Model semi-supervised baseline.
# Each command trains on one dataset with 4 labeled samples per class;
# the second half swaps the ImageNet-supervised backbone for a MoCo v2
# checkpoint via --pretrained-backbone and a cosine lr schedule.
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/pi_model/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/pi_model/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/pi_model/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/pseudo_label.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for pseudo-label semi-supervised training.

    Builds the labeled/unlabeled/validation data pipelines, the classifier,
    the optimizer and lr scheduler, then either evaluates the best saved
    checkpoint (``args.phase == 'test'``) or trains for ``args.epochs``
    epochs, keeping the checkpoint with the best validation acc@1.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    # labeled samples get a (weak, strong) pair of views; unlabeled samples
    # only the weak view (their pseudo labels come from the model itself)
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = weak_augment
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)
    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))
    # drop_last keeps every training batch at full batch_size
    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    # infinite iterators so train() can draw a fixed number of batches per epoch
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)
    # define optimizer and lr scheduler
    # 'exp': inverse-decay schedule driven by --lr-gamma / --lr-decay;
    # otherwise ('cos'): cosine schedule with warm-up over all iterations
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)
    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return
    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        # 'best' is selected by acc@1; best_avg is tracked independently and
        # may come from a different epoch than best_acc1
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train *model* for one epoch of confidence-based pseudo-labeling.

    Each iteration combines supervised cross-entropy on the labeled batch
    (weak + strong views) with a self-training loss on the unlabeled batch,
    where pseudo labels are taken from the model's own predictions and kept
    only above the confidence threshold ``args.threshold``.  The true
    unlabeled labels are used solely to report pseudo-label accuracy.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs],
        prefix="Epoch: [{}]".format(epoch))
    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    # switch to train mode
    model.train()
    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)
        # labels_u is ground truth, used only for the pseudo-label accuracy metric
        x_u, labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        labels_u = labels_u.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # clear grad
        optimizer.zero_grad()
        # compute output
        # cross entropy loss
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # backward immediately so gradients accumulate before the unlabeled pass
        cls_loss.backward()
        # self training loss
        # the same logits are passed as both arguments: vanilla pseudo-labeling,
        # where targets and predictions come from one forward pass on the weak view
        y_u = model(x_u)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u, y_u)
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()
        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)
        # accuracy of pseudo labels
        n_pseudo_labels = mask.sum()
        if n_pseudo_labels > 0:
            # overwrite below-threshold entries with -1 so they can never
            # match a true class index (assumes mask is a 0/1 tensor over
            # the batch — TODO confirm against ConfidenceBasedSelfTrainingLoss)
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            # NOTE(review): n_pseudo_labels is a tensor here, not an int —
            # verify AverageMeter handles a tensor count as intended
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)
        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface; pseudo_label.sh contains the per-dataset invocations.
    parser = argparse.ArgumentParser(description='Pseudo Label for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    # strong augmentation policy applied to labeled images in addition to the weak view
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    # passed to ConfidenceBasedSelfTrainingLoss: pseudo labels below this
    # confidence are masked out of the self-training loss
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='confidence threshold (default: 0.95)')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    # --lr-gamma and --lr-decay are only used by the 'exp' scheduler (see main)
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run (default: 40)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='pseudo_label', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/pseudo_label.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for the confidence-based pseudo-label SSL baseline.
# Each command trains on one dataset with 4 labeled samples per class; the
# --threshold flag is tuned per dataset.  The second half swaps the
# ImageNet-supervised backbone for a MoCo v2 checkpoint and a cosine schedule.
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/pseudo_label_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/requirements.txt
================================================
timm
================================================
FILE: examples/semi_supervised_learning/image_classification/self_tuning.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import utils
from tllib.self_training.self_tuning import Classifier, SelfTuning
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for Self-Tuning semi-supervised image classification.

    Builds labeled/unlabeled/validation data pipelines, a query/key
    classifier pair wrapped by ``SelfTuning``, and SGD with a multi-step LR
    schedule.  In ``--phase test`` only the best checkpoint is evaluated;
    otherwise trains for ``args.epochs`` epochs, keeping the checkpoint with
    the best top-1 validation accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    # Two independent strong augmentations of the same image produce the
    # query/key views consumed by the contrastive (PGC) branch.
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    train_transform = MultipleApply([strong_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('train_transform: ', train_transform)
    print('val_transform:', val_transform)
    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, train_transform,
                          val_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))
    # drop_last keeps every training batch at full batch_size
    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    num_classes = labeled_train_dataset.num_classes
    # query encoder: optionally initialized from a pretrained backbone checkpoint
    backbone_q = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier_q = Classifier(backbone_q, num_classes, projection_dim=args.projection_dim,
                              bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                              finetune=args.finetune).to(device)
    print(classifier_q)
    # key encoder: separate network of the same architecture, managed by SelfTuning
    backbone_k = utils.get_model(args.arch)
    classifier_k = Classifier(backbone_k, num_classes, projection_dim=args.projection_dim,
                              bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(device)
    selftuning = SelfTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T).to(device)
    # define optimizer and lr scheduler
    # only the query classifier's parameters are optimized directly
    optimizer = SGD(classifier_q.get_parameters(args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, gamma=args.lr_gamma)
    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier_q.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes)
        print(acc1)
        return
    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, selftuning, optimizer, epoch, args)
        # update lr (MultiStepLR is stepped once per epoch)
        lr_scheduler.step()
        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, selftuning: SelfTuning,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Train Self-Tuning for one epoch.

    Per iteration the objective has three terms:

    * cross-entropy on the labeled query view,
    * PGC (pseudo group contrast) KL loss on the labeled query/key pair,
    * PGC KL loss on the unlabeled pair, where the query classifier's
      argmax predictions serve as pseudo labels.

    The labeled and unlabeled losses are back-propagated separately (two
    ``backward`` calls accumulate into the same gradients) before a single
    optimizer step, which keeps peak memory lower than one joint graph.

    Args:
        labeled_train_iter: infinite iterator over labeled (query, key) pairs.
        unlabeled_train_iter: infinite iterator over unlabeled (query, key) pairs.
        selftuning: SelfTuning module wrapping the query/key classifiers and queue.
        optimizer: SGD over the query classifier's parameters.
        epoch: current epoch index (for logging only).
        args: parsed command-line arguments.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    pgc_losses_labeled = AverageMeter('Pgc Loss (Labeled Data)', ':3.2f')
    pgc_losses_unlabeled = AverageMeter('Pgc Loss (Unlabeled Data)', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, pgc_losses_labeled, pgc_losses_unlabeled, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    # define loss functions
    criterion_kl = nn.KLDivLoss(reduction='batchmean').to(device)
    # switch to train mode
    selftuning.train()
    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (l_q, l_k), labels_l = next(labeled_train_iter)
        (u_q, u_k), _ = next(unlabeled_train_iter)
        l_q, l_k = l_q.to(device), l_k.to(device)
        u_q, u_k = u_q.to(device), u_k.to(device)
        labels_l = labels_l.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # clear grad
        optimizer.zero_grad()
        # compute output
        pgc_logits_labeled, pgc_labels_labeled, y_l = selftuning(l_q, l_k, labels_l)
        # cross entropy loss
        cls_loss = F.cross_entropy(y_l, labels_l)
        # pgc loss on labeled samples
        pgc_loss_labeled = criterion_kl(pgc_logits_labeled, pgc_labels_labeled)
        (cls_loss + pgc_loss_labeled).backward()
        # pgc loss on unlabeled samples
        # The prediction is only consumed through argmax to form pseudo labels,
        # so no gradient ever flows through this forward pass -- run it under
        # no_grad to avoid building (and retaining) an unused autograd graph.
        with torch.no_grad():
            _, y_pred = selftuning.encoder_q(u_q)
        _, pseudo_labels = torch.max(y_pred, dim=1)
        pgc_logits_unlabeled, pgc_labels_unlabeled, _ = selftuning(u_q, u_k, pseudo_labels)
        pgc_loss_unlabeled = criterion_kl(pgc_logits_unlabeled, pgc_labels_unlabeled)
        pgc_loss_unlabeled.backward()
        # compute gradient and do SGD step
        optimizer.step()
        # measure accuracy and record loss
        cls_losses.update(cls_loss.item(), batch_size)
        pgc_losses_labeled.update(pgc_loss_labeled.item(), batch_size)
        pgc_losses_unlabeled.update(pgc_loss_unlabeled.item(), batch_size)
        loss = cls_loss + pgc_loss_labeled + pgc_loss_unlabeled
        losses.update(loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface for Self-Tuning; defaults follow the settings
    # used by the accompanying self_tuning.sh reproduction script.
    parser = argparse.ArgumentParser(description='Self Tuning for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--projection-dim', default=1024, type=int,
                        help='dimension of projection head')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    # --T, --K and --m are the contrastive-queue hyper-parameters forwarded to SelfTuning
    parser.add_argument('--T', default=0.07, type=float,
                        help="temperature (default: 0.07)")
    parser.add_argument('--K', default=32, type=int,
                        help="queue size (default: 32)")
    parser.add_argument('--m', default=0.999, type=float,
                        help="momentum coefficient (default: 0.999)")
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-gamma', default=0.1, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--milestones', default=[12, 24, 36, 48], type=int, nargs='+',
                        help='epochs to decay learning rate')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run (default: 60)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='self_tuning', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/self_tuning.sh
================================================
#!/usr/bin/env bash
# Reproduction commands for Self-Tuning on semi-supervised image
# classification benchmarks (4 labeled samples per class throughout).
# First half uses an ImageNet supervised-pretrained ResNet50; second half
# uses a MoCo v2 unsupervised-pretrained backbone, which is expected at
# checkpoints/moco_v2_800ep_backbone.pth (obtain it beforehand).
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/self_tuning/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/self_tuning/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/self_tuning/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/uda.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import utils
from tllib.self_training.uda import StrongWeakConsistencyLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for UDA (Unsupervised Data Augmentation) semi-supervised
    image classification.

    Builds weak/strong augmentation pairs for both labeled and unlabeled
    data, an image classifier, and SGD with either an exponential-style or
    cosine-with-warmup LR schedule.  In ``--phase test`` only the best
    checkpoint is evaluated; otherwise trains for ``args.epochs`` epochs,
    keeping the checkpoint with the best top-1 validation accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    # Each sample yields a (weak, strong) view pair; the weak view supervises
    # the strong view on unlabeled data (consistency training).
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)
    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data,
                          args.num_samples_per_class,
                          args.root, labeled_train_transform,
                          val_transform,
                          unlabeled_train_transform=unlabeled_train_transform,
                          seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))
    # drop_last keeps every training batch at full batch_size
    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer,
                                       finetune=args.finetune).to(device)
    print(classifier)
    # define optimizer and lr scheduler
    # both schedulers are stepped per *iteration* inside train()
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)
    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return
    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)
        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one UDA training epoch.

    The objective combines a supervised cross-entropy term on labeled data
    (weak view, plus a down-weighted strong view) with a strong/weak
    consistency term on unlabeled data, where the weak-view prediction is
    computed without gradients and acts as a fixed target for the strong
    view.  The LR scheduler is stepped once per iteration.
    """
    meter_time = AverageMeter('Time', ':2.2f')
    meter_data = AverageMeter('Data', ':2.1f')
    meter_cls = AverageMeter('Cls Loss', ':3.2f')
    meter_con = AverageMeter('Con Loss', ':3.2f')
    meter_total = AverageMeter('Loss', ':3.2f')
    meter_acc = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_total, meter_cls, meter_con, meter_acc],
        prefix="Epoch: [{}]".format(epoch))
    # confidence-thresholded, temperature-sharpened consistency objective
    consistency_criterion = StrongWeakConsistencyLoss(args.threshold, args.T).to(device)
    model.train()
    tic = time.time()
    n = args.batch_size
    for step in range(args.iters_per_epoch):
        # one labeled and one unlabeled batch, each as a (weak, strong) pair
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)
        (x_u, x_u_strong), _ = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        meter_data.update(time.time() - tic)
        optimizer.zero_grad()
        # supervised term: weak view plus down-weighted strong view
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss.backward()
        # consistency term: the weak-view output is a constant target,
        # so only the strong-view forward pass tracks gradients
        with torch.no_grad():
            y_u = model(x_u)
        y_u_strong = model(x_u_strong)
        con_loss = args.trade_off_con * consistency_criterion(y_u_strong, y_u)
        con_loss.backward()
        # bookkeeping (losses are scalars; .item() detaches them)
        meter_total.update((cls_loss + con_loss).item(), n)
        meter_cls.update(cls_loss.item(), n)
        meter_con.update(con_loss.item(), n)
        meter_acc.update(accuracy(y_l, labels_l)[0].item(), n)
        # apply accumulated gradients, then advance the per-iteration schedule
        optimizer.step()
        lr_scheduler.step()
        meter_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface for UDA; defaults follow the settings used by
    # the accompanying uda.sh reproduction script.
    parser = argparse.ArgumentParser(description='UDA for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    # --threshold and --T are forwarded to StrongWeakConsistencyLoss
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-con', default=1, type=float,
                        help='the trade-off hyper-parameter of consistency loss')
    parser.add_argument('--threshold', default=0.7, type=float,
                        help='confidence threshold')
    parser.add_argument('--T', default=0.85, type=float,
                        help='temperature')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run (default: 60)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='uda', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/semi_supervised_learning/image_classification/uda.sh
================================================
#!/usr/bin/env bash
# ImageNet Supervised Pretrain (ResNet50)
# NOTE(review): `--threshold` is presumably the pseudo-label confidence threshold used
# by uda.py (tuned per dataset) — confirm against uda.py's argument parser.
# `--num-samples-per-class 4` selects 4 labeled images per class; the rest are unlabeled.
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/uda/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/uda/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/caltech_4_labels_per_class
# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# These runs start from a converted MoCo-v2 backbone checkpoint and use a cosine lr schedule.
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/uda_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/caltech_4_labels_per_class
================================================
FILE: examples/semi_supervised_learning/image_classification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import math
import sys
import time
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataset import Subset, ConcatDataset
import torchvision.transforms as T
import timm
from timm.data.auto_augment import auto_augment_transform, rand_augment_transform
sys.path.append('../../..')
from tllib.modules.classifier import Classifier
import tllib.vision.datasets as datasets
import tllib.vision.models as models
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
def get_model_names():
    """Return the sorted names of all supported backbones.

    Combines the architectures defined in ``tllib.vision.models`` with every
    model registered in `pytorch-image-models` (timm).
    """
    local_names = [
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ]
    return sorted(local_names) + timm.list_models()
def get_model(model_name, pretrained=True, pretrained_checkpoint=None):
    """Create a feature-extracting backbone by name.

    The architecture is looked up in ``tllib.vision.models`` first and falls
    back to `pytorch-image-models` (timm). For timm models the classification
    head is stripped so the module outputs features, and the feature dimension
    is exposed as ``backbone.out_features``.

    Args:
        model_name (str): architecture name, e.g. ``'resnet50'``.
        pretrained (bool): whether to load pretrained weights. Default: True.
        pretrained_checkpoint (str, optional): path to a state dict loaded
            (non-strictly) on top of the backbone, e.g. a converted MoCo checkpoint.

    Returns:
        nn.Module: the backbone network.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=pretrained)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=pretrained)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except AttributeError:
            # some timm models expose the classifier as `.head` rather than via
            # get_classifier(); the original bare `except:` also hid real bugs
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    if pretrained_checkpoint:
        print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint))
        # map to CPU so loading does not require the device the checkpoint was saved on
        pretrained_dict = torch.load(pretrained_checkpoint, map_location='cpu')
        backbone.load_state_dict(pretrained_dict, strict=False)
    return backbone
def get_dataset_names():
    """Return the sorted names of all dataset classes in ``tllib.vision.datasets``."""
    valid_names = []
    for name in datasets.__dict__:
        if name.startswith("__") or not callable(datasets.__dict__[name]):
            continue
        valid_names.append(name)
    return sorted(valid_names)
def get_dataset(dataset_name, num_samples_per_class, root, labeled_train_transform, val_transform,
                unlabeled_train_transform=None, seed=0):
    """Build the labeled / unlabeled / validation splits for semi-supervised learning.

    A class-balanced labeled subset with ``num_samples_per_class`` samples per
    class is drawn from the training split (deterministic w.r.t. ``seed``); the
    remaining training samples form the unlabeled subset. For
    ``OxfordFlowers102`` the dataset's extra ``validation`` split is appended to
    the unlabeled data. The ``test`` split is used for evaluation.

    The original implementation duplicated the whole body in both branches of
    the ``OxfordFlowers102`` check; only the extra-unlabeled-data step differs.

    Args:
        dataset_name (str): name of a dataset class in ``tllib.vision.datasets``.
        num_samples_per_class (int): number of labeled samples per class.
        root (str): dataset root directory.
        labeled_train_transform: transform applied to labeled training samples.
        val_transform: transform applied to evaluation samples.
        unlabeled_train_transform: transform applied to unlabeled samples;
            defaults to ``labeled_train_transform``.
        seed (int): random seed of the labeled/unlabeled split. Default: 0.

    Returns:
        (labeled_train_dataset, unlabeled_train_dataset, val_dataset)
    """
    if unlabeled_train_transform is None:
        unlabeled_train_transform = labeled_train_transform
    dataset = datasets.__dict__[dataset_name]
    base_dataset = dataset(root=root, split='train', transform=labeled_train_transform, download=True)
    # create labeled and unlabeled splits
    labeled_idxes, unlabeled_idxes = x_u_split(num_samples_per_class, base_dataset.num_classes,
                                               base_dataset.targets, seed=seed)
    # labeled subset
    labeled_train_dataset = Subset(base_dataset, labeled_idxes)
    labeled_train_dataset.num_classes = base_dataset.num_classes
    # unlabeled subset (instantiated again so it can carry a different transform)
    unlabeled_base_dataset = dataset(root=root, split='train', transform=unlabeled_train_transform, download=True)
    unlabeled_train_dataset = Subset(unlabeled_base_dataset, unlabeled_idxes)
    if dataset_name == 'OxfordFlowers102':
        # OxfordFlowers102 ships a separate validation split; use it as extra unlabeled data
        unlabeled_train_dataset = ConcatDataset([
            unlabeled_train_dataset,
            dataset(root=root, split='validation', download=True, transform=unlabeled_train_transform)
        ])
    val_dataset = dataset(root=root, split='test', download=True, transform=val_transform)
    return labeled_train_dataset, unlabeled_train_dataset, val_dataset
def x_u_split(num_samples_per_class, num_classes, labels, seed):
    """
    Construct labeled and unlabeled subsets, where the labeled subset is class balanced. Note that the resulting
    subsets are **deterministic** with the same random seed.

    Args:
        num_samples_per_class (int): number of labeled samples drawn per class (without replacement).
        num_classes (int): number of classes; classes are assumed to be labeled ``0 .. num_classes - 1``.
        labels (sequence): label of every sample in the dataset.
        seed (int): seed of the `numpy` random state used for sampling.

    Returns:
        (labeled_idxes, unlabeled_idxes): index lists that partition ``range(len(labels))``.
    """
    labels = np.array(labels)
    assert num_samples_per_class * num_classes <= len(labels)
    random_state = np.random.RandomState(seed)
    # labeled subset: sample per class so the result is class balanced
    labeled_idxes = []
    for c in range(num_classes):
        class_idxes = np.where(labels == c)[0]
        class_idxes = random_state.choice(class_idxes, num_samples_per_class, False)
        labeled_idxes.extend(class_idxes)
    # unlabeled subset = complement of the labeled subset.
    # Use a set for O(1) membership tests; the original tested `i not in list`,
    # which made this step O(n^2) in the dataset size.
    labeled_set = set(labeled_idxes)
    unlabeled_idxes = [i for i in range(len(labels)) if i not in labeled_set]
    return labeled_idxes, unlabeled_idxes
def get_train_transform(resizing='default', random_horizontal_flip=True, auto_augment=None,
                        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build the training-time preprocessing pipeline.

    Args:
        resizing (str): ``'default'`` for a 224x224 random resized crop, or
            ``'cifar'`` for a padded 32x32 random crop followed by resizing to 224.
        random_horizontal_flip (bool): whether to add a random horizontal flip.
        auto_augment (str, optional): timm auto-augment spec; names starting with
            ``'rand'`` select RandAugment, anything else AutoAugment.
        norm_mean (tuple): per-channel normalization mean.
        norm_std (tuple): per-channel normalization std.

    Raises:
        NotImplementedError: if ``resizing`` is not one of the supported modes.
    """
    if resizing == 'default':
        resize = T.RandomResizedCrop(224, scale=(0.2, 1.))
    elif resizing == 'cifar':
        resize = T.Compose([
            T.RandomCrop(size=32, padding=4, padding_mode='reflect'),
            ResizeImage(224)
        ])
    else:
        raise NotImplementedError(resizing)
    pipeline = [resize]
    if random_horizontal_flip:
        pipeline.append(T.RandomHorizontalFlip())
    if auto_augment:
        aa_params = dict(
            translate_const=int(224 * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]),
            interpolation=Image.BILINEAR
        )
        aa_factory = rand_augment_transform if auto_augment.startswith('rand') else auto_augment_transform
        pipeline.append(aa_factory(auto_augment, aa_params))
    pipeline.append(T.ToTensor())
    pipeline.append(T.Normalize(mean=norm_mean, std=norm_std))
    return T.Compose(pipeline)
def get_val_transform(resizing='default', norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Build the evaluation-time preprocessing pipeline.

    Args:
        resizing (str): ``'default'`` resizes to 256 then center-crops 224;
            ``'cifar'`` resizes directly to 224.
        norm_mean (tuple): per-channel normalization mean.
        norm_std (tuple): per-channel normalization std.

    Raises:
        NotImplementedError: if ``resizing`` is not one of the supported modes.
    """
    if resizing == 'cifar':
        resize = ResizeImage(224)
    elif resizing == 'default':
        resize = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
        ])
    else:
        raise NotImplementedError(resizing)
    normalize = T.Normalize(mean=norm_mean, std=norm_std)
    return T.Compose([resize, T.ToTensor(), normalize])
def convert_dataset(dataset):
    """Wrap ``dataset`` so each access also yields its index.

    The wrapper's ``__getitem__`` returns ``(index, dataset[index])`` — i.e. a
    pair whose second element is whatever the underlying dataset returns
    (typically an ``(img, label)`` tuple).
    """
    class _IndexedDataset:
        def __init__(self, wrapped):
            self._wrapped = wrapped

        def __getitem__(self, idx):
            return idx, self._wrapped[idx]

        def __len__(self):
            return len(self._wrapped)

    return _IndexedDataset(dataset)
class ImageClassifier(Classifier):
    """Image classifier with a bottleneck (Linear -> BN -> ReLU -> Dropout) between backbone and head.

    The bottleneck's linear layer is initialized with small Gaussian weights
    (std=0.005) and constant bias 0.1 before being handed to the base
    ``Classifier``.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: int = 1024, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # explicit initialization of the bottleneck's linear layer
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images.

        Unlike the base ``Classifier`` (which presumably also returns features —
        confirm in tllib.modules.classifier), this override returns only predictions.
        """
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions
def get_cosine_scheduler_with_warmup(optimizer, T_max, num_cycles=7. / 16., num_warmup_steps=0,
                                     last_epoch=-1):
    """
    Cosine learning rate scheduler with optional linear warmup, as used in
    `FixMatch: Simplifying Semi-Supervised Learning with Consistency and
    Confidence (NeurIPS 2020) <https://arxiv.org/abs/2001.07685>`_.

    During warmup the multiplier grows linearly from 0 to 1; afterwards it
    follows ``cos(pi * num_cycles * progress)`` clamped at 0, where ``progress``
    goes from 0 to 1 over the remaining ``T_max - num_warmup_steps`` steps.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        num_cycles (float): A scalar that controls the shape of cosine function. Default: 7/16.
        num_warmup_steps (int): Number of iterations to warm up. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
    """
    def _lr_lambda(current_step):
        # linear warmup phase
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # cosine decay phase, clamped so the multiplier never goes negative
        progress = float(current_step - num_warmup_steps) / float(max(1, T_max - num_warmup_steps))
        return max(0.0, math.cos(math.pi * num_cycles * progress))

    return LambdaLR(optimizer, _lr_lambda, last_epoch)
def validate(val_loader, model, args, device, num_classes):
    """Evaluate ``model`` on ``val_loader``.

    Computes cross-entropy loss, top-1/top-5 accuracy and a per-class confusion
    matrix, printing progress every ``args.print_freq`` batches.

    Args:
        val_loader: DataLoader yielding ``(images, target)`` batches.
        model: network to evaluate (switched to eval mode here).
        args: must provide ``print_freq``.
        device: device to move batches to.
        num_classes (int): number of classes for the confusion matrix.

    Returns:
        (float, float): average top-1 accuracy and mean per-class accuracy (percent).
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')
    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)
            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)
            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % args.print_freq == 0:
                progress.display(i)
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))
    # mean per-class accuracy from the confusion matrix (scaled to percent)
    acc_global, acc_per_class, iu = confmat.compute()
    mean_cls_acc = acc_per_class.mean().item() * 100
    print(' * Mean Cls {:.3f}'.format(mean_cls_acc))
    return top1.avg, mean_cls_acc
def empirical_risk_minimization(labeled_train_iter, model, optimizer, lr_scheduler, epoch, args, device):
    """Train ``model`` for one epoch with supervised cross-entropy on labeled data only.

    Each batch provides two augmented views of the same labeled images
    (``x_l`` and ``x_l_strong`` — presumably weak and strong augmentation,
    judging by the names; confirm against the data pipeline). The loss is
    ``CE(weak) + args.trade_off_cls_strong * CE(strong)``.

    Args:
        labeled_train_iter: infinite iterator over labeled batches
            yielding ``((x_l, x_l_strong), labels_l)``.
        model: classifier being trained.
        optimizer: optimizer stepped once per iteration.
        lr_scheduler: stepped once per iteration (per-step schedule).
        epoch (int): current epoch index (for logging only).
        args: must provide ``iters_per_epoch``, ``batch_size``,
            ``trade_off_cls_strong`` and ``print_freq``.
        device: device to move batches to.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))
    # switch to train mode
    model.train()
    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)
        # measure data loading time
        data_time.update(time.time() - end)
        # compute output
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        # cross entropy loss on both weak augmented and strong augmented samples
        loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # measure accuracy and record loss
        losses.update(loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)
        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % args.print_freq == 0:
            progress.display(i)
================================================
FILE: examples/task_adaptation/image_classification/README.md
================================================
# Task Adaptation for Image Classification
## Installation
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You
need to install timm to use PyTorch-Image-Models.
```
pip install timm
```
## Dataset
Following datasets can be downloaded automatically:
- [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
- [StanfordDogs](http://vision.stanford.edu/aditya86/ImageNetDogs/)
- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)
- [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)
- [PatchCamelyon](https://patchcamelyon.grand-challenge.org/)
- [EuroSAT](https://github.com/phelber/eurosat)
You need to prepare following datasets manually if you want to use them:
- [Retinopathy](https://www.kaggle.com/c/diabetic-retinopathy-detection/data)
- [Resisc45](http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html)
and prepare them following [Documentation for Retinopathy](/tllib/vision/datasets/retinopathy.py)
and [Resisc45](/tllib/vision/datasets/resisc45.py).
## Supported Methods
Supported methods include:
- [Explicit inductive bias for transfer learning with convolutional networks
(L2-SP, ICML 2018)](https://arxiv.org/abs/1802.01483)
- [Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf)
- [DEep Learning Transfer using Feature Map with Attention for convolutional networks (DELTA, ICLR 2019)](https://openreview.net/pdf?id=rkgbwsAcYm)
- [Co-Tuning for Transfer Learning (Co-Tuning, NIPS 2020)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf)
- [Stochastic Normalization (StochNorm, NIPS 2020)](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf)
- [Learning Without Forgetting (LWF, ECCV 2016)](https://arxiv.org/abs/1606.09282)
- [Bi-tuning of Pre-trained Representations (Bi-Tuning)](https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29)
## Experiment and Results
We follow the common practice in the community as described
in [Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf).
Training iterations and data augmentations are kept the same for different task-adaptation methods for a fair
comparison.
Hyper-parameters of each method are selected by the performance on target validation data.
### Fine-tune the supervised pre-trained model
The shell files give the script to reproduce the supervised pretrained benchmarks with specified hyper-parameters. For
example, if you want to use vanilla fine-tune on CUB200, use the following script
```shell script
# Fine-tune ResNet50 on CUB200.
# Assume you have put the datasets under the path `data/cub200`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100
```
#### Vision Benchmark on ResNet-50
| | Food101 | CIFAR10 | CIFAR100 | SUN397 | Stanford Cars | FGVC Aircraft | DTD | Oxford-IIIT Pets | Caltech-101 | Oxford 102 Flowers | average |
|-----------------|---------|---------|----------|--------|----------------|----------------|-------|------------------|----------------|--------------------|---------|
| Accuracy metric | top1 | top1 | top1 | top1 | top1 | mean per-class | top-1 | mean per-class | mean per-class | mean per-class | |
| Baseline | 85.1 | 96.6 | 84.1 | 63.7 | 87.8 | 80.1 | 70.8 | 93.2 | 91.1 | 93.0 | 84.6 |
| LWF | 83.9 | 96.5 | 83.6 | 64.1 | 87.4 | 82.2 | 72.2 | 94.0 | 89.8 | 92.9 | 84.7 |
| DELTA | 83.8 | 95.9 | 83.7 | 64.5 | 88.1 | 82.3 | 72.2 | 94.2 | 90.1 | 93.1 | 84.8 |
| BSS | 85.0 | 96.6 | 84.2 | 63.5 | 88.4 | 81.8 | 70.2 | 93.3 | 91.6 | 92.7 | 84.7 |
| StochNorm | 85.0 | 96.8 | 83.9 | 63.0 | 87.7 | 81.5 | 71.3 | 93.6 | 90.5 | 92.9 | 84.6 |
| Bi-Tuning | 85.7 | 97.1 | 84.3 | 64.2 | 90.3 | 84.8 | 70.6 | 93.5 | 91.5 | 94.5 | 85.7 |
#### CUB-200-2011 on ResNet-50 (Supervised Pre-trained)
| CUB200 | 15% | 30% | 50% | 100% | Avg |
|-----------|------|------|------|------|------|
| ERM | 51.2 | 64.6 | 74.6 | 81.8 | 68.1 |
| lwf | 56.7 | 66.8 | 73.4 | 81.5 | 69.6 |
| BSS | 53.4 | 66.7 | 76.0 | 82.0 | 69.5 |
| delta | 54.8 | 67.3 | 76.3 | 82.3 | 70.2 |
| StochNorm | 54.8 | 66.8 | 75.8 | 82.2 | 69.9 |
| Co-tuning | 57.6 | 70.1 | 77.3 | 82.5 | 71.9 |
| bi-tuning | 55.8 | 69.3 | 77.2 | 83.1 | 71.4 |
#### Stanford Cars on ResNet-50 (Supervised Pre-trained)
| Stanford Cars | 15% | 30% | 50% | 100% | Avg |
|----------------|------|------|------|------|------|
| ERM | 41.1 | 65.9 | 78.4 | 87.8 | 68.3 |
| lwf | 44.9 | 67.0 | 77.6 | 87.5 | 69.3 |
| BSS | 43.3 | 67.6 | 79.6 | 88.0 | 69.6 |
| delta | 45.0 | 68.4 | 79.6 | 88.4 | 70.4 |
| StochNorm | 44.4 | 68.1 | 79.3 | 87.9 | 69.9 |
| Co-tuning | 49.0 | 70.6 | 81.9 | 89.1 | 72.7 |
| bi-tuning | 48.3 | 72.8 | 83.3 | 90.2 | 73.7 |
#### FGVC Aircraft on ResNet-50 (Supervised Pre-trained)
| FGVC Aircraft | 15% | 30% | 50% | 100% | Avg |
|---------------|------|------|------|------|------|
| ERM | 41.6 | 57.8 | 68.7 | 80.2 | 62.1 |
| lwf | 44.1 | 60.6 | 68.7 | 82.4 | 64.0 |
| BSS | 43.6 | 59.5 | 69.6 | 81.2 | 63.5 |
| delta | 44.4 | 61.9 | 71.4 | 82.7 | 65.1 |
| StochNorm | 44.3 | 60.6 | 70.1 | 81.5 | 64.1 |
| Co-tuning | 45.9 | 61.2 | 71.3 | 82.2 | 65.2 |
| bi-tuning | 47.2 | 64.3 | 73.7 | 84.3 | 67.4 |
### Fine-tune the unsupervised pre-trained model
Take MoCo as an example.
1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco
2. Convert the format of the MoCo checkpoints to the standard format of pytorch
```shell
mkdir checkpoints
python convert_moco_to_pretrained.py checkpoints/moco_v1_200ep_pretrain.pth.tar checkpoints/moco_v1_200ep_backbone.pth checkpoints/moco_v1_200ep_fc.pth
```
3. Start training
```shell
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
```
#### CUB-200-2011 on ResNet-50 (MoCo Pre-trained)
| CUB200 | 15% | 30% | 50% | 100% | Avg |
|-----------|------|------|------|------|------|
| ERM | 28.0 | 48.2 | 62.7 | 75.6 | 53.6 |
| lwf | 28.8 | 50.1 | 62.8 | 76.2 | 54.5 |
| BSS | 30.9 | 50.3 | 63.7 | 75.8 | 55.2 |
| delta | 27.9 | 51.4 | 65.9 | 74.6 | 55.0 |
| StochNorm | 20.8 | 44.9 | 60.1 | 72.8 | 49.7 |
| Co-tuning | 29.1 | 50.1 | 63.8 | 75.9 | 54.7 |
| bi-tuning | 32.4 | 51.8 | 65.7 | 76.1 | 56.5 |
#### Stanford Cars on ResNet-50 (MoCo Pre-trained)
| Stanford Cars | 15% | 30% | 50% | 100% | Avg |
|----------------|------|------|------|------|------|
| ERM | 42.5 | 71.2 | 83.0 | 90.1 | 71.7 |
| lwf | 44.2 | 71.7 | 82.9 | 90.5 | 72.3 |
| BSS | 45.0 | 71.5 | 83.8 | 90.1 | 72.6 |
| delta | 45.9 | 72.9 | 82.5 | 88.9 | 72.6 |
| StochNorm | 40.3 | 66.2 | 78.0 | 86.2 | 67.7 |
| Co-tuning | 44.2 | 72.6 | 83.3 | 90.3 | 72.6 |
| bi-tuning | 45.6 | 72.8 | 83.2 | 90.8 | 73.1 |
#### FGVC Aircraft on ResNet-50 (MoCo Pre-trained)
| FGVC Aircraft | 15% | 30% | 50% | 100% | Avg |
|---------------|------|------|------|------|------|
| ERM | 45.8 | 67.6 | 78.8 | 88.0 | 70.1 |
| lwf | 48.5 | 68.5 | 78.0 | 87.9 | 70.7 |
| BSS | 47.7 | 69.1 | 79.2 | 88.0 | 71.0 |
| delta | - | - | - | - | - |
| StochNorm | 45.4 | 68.8 | 76.7 | 86.1 | 69.3 |
| Co-tuning | 48.2 | 68.5 | 78.7 | 87.3 | 70.7 |
| bi-tuning | 46.4 | 69.6 | 79.4 | 87.9 | 70.8 |
## Citation
If you use these methods in your research, please consider citing.
```
@inproceedings{LWF,
author = {Zhizhong Li and
Derek Hoiem},
title = {Learning without Forgetting},
booktitle={ECCV},
year = {2016},
}
@inproceedings{L2SP,
title={Explicit inductive bias for transfer learning with convolutional networks},
author={Xuhong, LI and Grandvalet, Yves and Davoine, Franck},
booktitle={ICML},
year={2018},
}
@inproceedings{BSS,
title={Catastrophic forgetting meets negative transfer: Batch spectral shrinkage for safe transfer learning},
author={Chen, Xinyang and Wang, Sinan and Fu, Bo and Long, Mingsheng and Wang, Jianmin},
booktitle={NeurIPS},
year={2019}
}
@inproceedings{DELTA,
title={Delta: Deep learning transfer using feature map with attention for convolutional networks},
author={Li, Xingjian and Xiong, Haoyi and Wang, Hanchao and Rao, Yuxuan and Liu, Liping and Chen, Zeyu and Huan, Jun},
booktitle={ICLR},
year={2019}
}
@inproceedings{StocNorm,
title={Stochastic Normalization},
author={Kou, Zhi and You, Kaichao and Long, Mingsheng and Wang, Jianmin},
booktitle={NeurIPS},
year={2020}
}
@inproceedings{CoTuning,
title={Co-Tuning for Transfer Learning},
author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin},
booktitle={NeurIPS},
year={2020}
}
@article{BiTuning,
title={Bi-tuning of Pre-trained Representations},
author={Zhong, Jincheng and Wang, Ximei and Kou, Zhi and Wang, Jianmin and Long, Mingsheng},
journal={arXiv preprint arXiv:2011.06182},
year={2020}
}
```
================================================
FILE: examples/task_adaptation/image_classification/bi_tuning.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import utils
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
from tllib.regularization.bi_tuning import Classifier, BiTuning
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Fine-tune a classifier with Bi-Tuning on the dataset described by ``args``.

    Builds two-view augmented data, a query classifier (trained) and a key
    classifier (momentum encoder, not in the optimizer), trains for
    ``args.epochs`` epochs with a MultiStep lr schedule, evaluates after each
    epoch and keeps the best checkpoint. With ``--phase test`` it only
    evaluates the best saved checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_augmentation = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    # two independently augmented views per image: one for the query encoder, one for the key encoder
    train_transform = MultipleApply([train_augmentation, train_augmentation])
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))
    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone_q = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier_q = Classifier(backbone_q, num_classes, pool_layer=pool_layer, projection_dim=args.projection_dim,
                              finetune=args.finetune)
    if args.pretrained_fc:
        print("=> loading pre-trained fc from '{}'".format(args.pretrained_fc))
        pretrained_fc_dict = torch.load(args.pretrained_fc)
        classifier_q.projector.load_state_dict(pretrained_fc_dict, strict=False)
    classifier_q = classifier_q.to(device)
    # the key classifier is a second copy; only classifier_q's parameters go to the optimizer
    backbone_k = utils.get_model(args.arch)
    classifier_k = Classifier(backbone_k, num_classes, pool_layer=pool_layer).to(device)
    bituning = BiTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T)
    # define optimizer and lr scheduler
    optimizer = SGD(classifier_q.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)
    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier_q.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier_q, args, device)
        print(acc1)
        return
    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # get_lr() is deprecated for external calls and can report incorrect
        # values for MultiStepLR; get_last_lr() returns the lr actually in use
        print(lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_iter, bituning, optimizer, epoch, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier_q, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, bituning: BiTuning, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Run one Bi-Tuning training epoch.

    Each iteration draws a two-view batch from ``train_iter``, computes the
    classification loss on the query view plus the two contrastive (KL) losses
    produced by the Bi-Tuning head, and takes one SGD step on their weighted
    sum (``args.trade_off`` scales the contrastive part).
    """
    time_meter = AverageMeter('Time', ':4.2f')
    data_meter = AverageMeter('Data', ':3.1f')
    cls_loss_meter = AverageMeter('Cls Loss', ':3.2f')
    contrastive_loss_meter = AverageMeter('Contrastive Loss', ':3.2f')
    loss_meter = AverageMeter('Loss', ':3.2f')
    acc_meter = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [time_meter, data_meter, loss_meter, cls_loss_meter, contrastive_loss_meter, acc_meter],
        prefix="Epoch: [{}]".format(epoch))
    classifier_criterion = torch.nn.CrossEntropyLoss().to(device)
    contrastive_criterion = torch.nn.KLDivLoss(reduction='batchmean').to(device)

    # switch to train mode
    bituning.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        x, labels = next(train_iter)
        # x holds two augmented views of the same images (query and key)
        img_q = x[0].to(device)
        img_k = x[1].to(device)
        labels = labels.to(device)
        # time spent waiting for data
        data_meter.update(time.time() - tic)

        # forward pass through the Bi-Tuning head
        y, logits_z, logits_y, bituning_labels = bituning(img_q, img_k, labels)
        cls_loss = classifier_criterion(y, labels)
        contrastive_loss = (contrastive_criterion(logits_z, bituning_labels)
                            + contrastive_criterion(logits_y, bituning_labels))
        loss = cls_loss + contrastive_loss * args.trade_off

        # bookkeeping
        batch_size = x[0].size(0)
        loss_meter.update(loss.item(), batch_size)
        cls_loss_meter.update(cls_loss.item(), batch_size)
        contrastive_loss_meter.update(contrastive_loss.item(), batch_size)
        acc_meter.update(accuracy(y, labels)[0].item(), batch_size)

        # gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time per iteration
        time_meter.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
# Command-line entry point: parse Bi-Tuning hyper-parameters and launch main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Bi-tuning for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--pretrained-fc', default=None,
                        help="pretrained checkpoint of the fc. "
                             "(default: None)")
    # contrastive-head hyper-parameters forwarded to the BiTuning module
    parser.add_argument('--T', default=0.07, type=float, help="temperature. (default: 0.07)")
    parser.add_argument('--K', type=int, default=40, help="queue size. (default: 40)")
    parser.add_argument('--m', type=float, default=0.999, help="momentum coefficient. (default: 0.999)")
    parser.add_argument('--projection-dim', type=int, default=128,
                        help="dimension of the projection head. (default: 128)")
    parser.add_argument('--trade-off', type=float, default=1.0, help="trade-off parameters. (default: 1.0)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    # NOTE(review): this flag appears unused — main() always constructs an SGD
    # optimizer regardless of its value; confirm before relying on '--optimizer Adam'.
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='bi_tuning',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/bi_tuning.sh
================================================
#!/usr/bin/env bash
# Bi-Tuning experiment launcher: each command runs one fine-tuning job.
# First section fine-tunes from the default (ImageNet supervised) backbone;
# the second section fine-tunes from a MoCo v1 unsupervised checkpoint.
# Supervised Pretraining
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --finetune --seed 0 --log logs/bi_tuning/cub200_100
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 50 --finetune --seed 0 --log logs/bi_tuning/cub200_50
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --finetune --seed 0 --log logs/bi_tuning/cub200_30
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --finetune --seed 0 --log logs/bi_tuning/cub200_15
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --finetune --seed 0 --log logs/bi_tuning/car_100
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --finetune --seed 0 --log logs/bi_tuning/car_50
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --finetune --seed 0 --log logs/bi_tuning/car_30
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --finetune --seed 0 --log logs/bi_tuning/car_15
# Aircraft
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/bi_tuning/aircraft_100
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bi_tuning/aircraft_50
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bi_tuning/aircraft_30
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bi_tuning/aircraft_15
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bi_tuning/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bi_tuning/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bi_tuning/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bi_tuning/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/bi_tuning/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bi_tuning/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bi_tuning/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bi_tuning/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bi_tuning/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft (comment fixed: this run uses the Aircraft dataset, not Stanford Cars)
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bi_tuning/aircraft/lr_1e-2 --lr 1e-2
# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Aircraft
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bi_tuning/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
================================================
FILE: examples/task_adaptation/image_classification/bss.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.regularization.bss import BatchSpectralShrinkage
from tllib.modules.classifier import Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for BSS fine-tuning.

    Builds the data loaders, the classifier and the Batch Spectral Shrinkage
    module, then either evaluates the best checkpoint (``args.phase == 'test'``)
    or trains for ``args.epochs`` epochs, keeping the checkpoint with the best
    top-1 validation accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeded runs also force deterministic CUDNN, which is slower.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)
    bss_module = BatchSpectralShrinkage(k=args.k)

    # define optimizer and lr scheduler; per-group lrs come from get_parameters()
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # BUGFIX: use get_last_lr() — calling the deprecated get_lr() outside of
        # scheduler.step() emits a warning and can report incorrect values.
        print(lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_iter, classifier, bss_module, optimizer, epoch, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, bss_module, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one BSS fine-tuning epoch.

    Each step minimizes the target cross-entropy plus ``args.trade_off`` times
    the Batch Spectral Shrinkage penalty computed on the backbone features.
    """
    meter_time = AverageMeter('Time', ':4.2f')
    meter_data = AverageMeter('Data', ':3.1f')
    meter_loss = AverageMeter('Loss', ':3.2f')
    meter_acc = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [meter_time, meter_data, meter_loss, meter_acc],
        prefix="Epoch: [{}]".format(epoch))

    # training mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, targets = next(train_iter)
        images = images.to(device)
        targets = targets.to(device)
        # time spent waiting for data
        meter_data.update(time.time() - tic)

        # forward: logits plus the features the BSS penalty acts on
        logits, features = model(images)
        total_loss = F.cross_entropy(logits, targets) + args.trade_off * bss_module(features)

        batch = images.size(0)
        meter_loss.update(total_loss.item(), batch)
        meter_acc.update(accuracy(logits, targets)[0].item(), batch)

        # backward + parameter update
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # wall-clock time per iteration
        meter_time.update(time.time() - tic)
        tic = time.time()
        if step % args.print_freq == 0:
            progress.display(step)
# Command-line entry point: parse BSS hyper-parameters and launch main().
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='BSS for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # BSS regularization hyper-parameters (forwarded to BatchSpectralShrinkage)
    parser.add_argument('-k', '--k', default=1, type=int,
                        metavar='N',
                        help='hyper-parameter for BSS loss')
    parser.add_argument('--trade-off', default=0.001, type=float,
                        metavar='P', help='trade-off weight of BSS loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='bss',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/bss.sh
================================================
#!/usr/bin/env bash
# BSS experiment launcher: each command runs one fine-tuning job.
# First section fine-tunes from the default (ImageNet supervised) backbone;
# the second section fine-tunes from a MoCo v1 unsupervised checkpoint.
# Supervised Pretraining
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/bss/cub200_100
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/bss/cub200_50
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/bss/cub200_30
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/bss/cub200_15
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/bss/car_100
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/bss/car_50
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/bss/car_30
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/bss/car_15
# Aircraft
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/bss/aircraft_100
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bss/aircraft_50
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bss/aircraft_30
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bss/aircraft_15
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python bss.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bss/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python bss.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bss/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bss/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bss/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python bss.py data/dtd -d DTD --seed 0 --finetune --log logs/bss/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python bss.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bss/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python bss.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bss/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python bss.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bss/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bss/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft (comment fixed: this run uses the Aircraft dataset, not Stanford Cars)
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bss/aircraft/lr_1e-2 --lr 1e-2
# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Aircraft
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
================================================
FILE: examples/task_adaptation/image_classification/co_tuning.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Subset
import utils
from tllib.regularization.co_tuning import CoTuningLoss, Relationship, Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.data import ForeverDataIterator
import tllib.vision.datasets as datasets
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):
    """Build the Co-Tuning datasets.

    Returns ``(train_dataset, determin_train_dataset, test_dataset,
    num_classes)``, where ``determin_train_dataset`` is the training split
    wrapped with the *validation* transform. When ``num_samples_per_classes``
    is given, both training views are restricted to the same random subset of
    at most ``num_samples_per_classes * num_classes`` indices.
    """
    dataset_cls = datasets.__dict__[dataset_name]
    if sample_rate < 100:
        # sub-sampled training split; the test split always uses the full data
        sampled_kwargs = dict(root=root, sample_rate=sample_rate, download=True)
        train_dataset = dataset_cls(split='train', transform=train_transform, **sampled_kwargs)
        determin_train_dataset = dataset_cls(split='train', transform=val_transform, **sampled_kwargs)
        test_dataset = dataset_cls(root=root, split='test', sample_rate=100, download=True, transform=val_transform)
    else:
        train_dataset = dataset_cls(root=root, split='train', transform=train_transform)
        determin_train_dataset = dataset_cls(root=root, split='train', transform=val_transform)
        test_dataset = dataset_cls(root=root, split='test', transform=val_transform)
    num_classes = train_dataset.num_classes
    if num_samples_per_classes is not None:
        # the same shuffled index prefix is shared by both training views so
        # they stay aligned sample-for-sample
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)
        keep = min(num_samples_per_classes * num_classes, len(train_dataset))
        train_dataset = Subset(train_dataset, indices[:keep])
        determin_train_dataset = Subset(determin_train_dataset, indices[:keep])
    return train_dataset, determin_train_dataset, test_dataset, num_classes
def main(args: argparse.Namespace):
    """Entry point for Co-Tuning fine-tuning.

    Builds the data loaders and the two-headed classifier (source head +
    target head), computes the source/target class relationship, then either
    evaluates the best checkpoint (``args.phase == 'test'``) or trains for
    ``args.epochs`` epochs, keeping the checkpoint with the best top-1
    validation accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)
    if args.seed is not None:
        # seeded runs also force deterministic CUDNN, which is slower
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
    cudnn.benchmark = True
    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)
    # determin_train_dataset is the training split with the validation
    # transform; it is iterated below in a fixed order (shuffle=False) to
    # build the class relationship
    train_dataset, determin_train_dataset, val_dataset, num_classes = get_dataset(args.data, args.root, train_transform,
                                                                                 val_transform, args.sample_rate,
                                                                                 args.num_samples_per_classes)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    determin_train_loader = DataLoader(determin_train_dataset, batch_size=args.batch_size,
                                       shuffle=False, num_workers=args.workers, drop_last=False)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    # create model: the source head is a copy of the pre-trained backbone's head
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer,
                            finetune=args.finetune).to(device)
    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)
    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return
    # build relationship between source classes and target classes
    source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source)
    # NOTE(review): the last argument looks like a file path used to cache the
    # relationship under the log directory — confirm against Relationship's signature
    relationship = Relationship(determin_train_loader, source_classifier, device,
                                os.path.join(logger.root, args.relationship))
    co_tuning_loss = CoTuningLoss()
    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_iter, classifier, optimizer, epoch, relationship, co_tuning_loss, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          epoch: int, relationship, co_tuning_loss, args: argparse.Namespace):
    """Train `model` for one epoch with the Co-Tuning objective.

    The total loss is the target-task cross-entropy plus ``args.trade_off``
    times the co-tuning loss, which pushes the source head's predictions
    towards the soft source labels looked up in ``relationship``.

    Args:
        train_iter: infinite iterator over (image, target-label) batches.
        model: classifier whose forward returns (source logits, target logits).
        optimizer: SGD optimizer over the model parameters.
        epoch: current epoch index (used for logging only).
        relationship: array mapping each target class index to a probability
            distribution over source classes.
        co_tuning_loss: CoTuningLoss instance for the source-side loss.
        args: parsed options (iters_per_epoch, trade_off, print_freq, ...).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x, label_t = next(train_iter)
        x = x.to(device)
        # Soft source label: the relationship row for each target label.
        # Bug fix: use `.to(device)` instead of `.cuda()` so the code also
        # runs on CPU-only machines, consistent with the rest of this file.
        label_s = torch.from_numpy(relationship[label_t]).to(device).float()
        label_t = label_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output of both heads
        y_s, y_t = model(x)
        tgt_loss = F.cross_entropy(y_t, label_t)
        src_loss = co_tuning_loss(y_s, label_s)
        loss = tgt_loss + args.trade_off * src_loss

        # measure accuracy and record loss
        losses.update(loss.item(), x.size(0))
        cls_acc = accuracy(y_t, label_t)[0]
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser(description='Co-Tuning for Finetuning')

    # ----- dataset parameters -----
    arg_parser.add_argument('root', metavar='DIR', help='root path of dataset')
    arg_parser.add_argument('-d', '--data', metavar='DATA')
    arg_parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N',
                            help='sample rate of training dataset (default: 100)')
    arg_parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                            help='number of samples per classes.')
    arg_parser.add_argument('--train-resizing', type=str, default='default',
                            help='resize mode during training')
    arg_parser.add_argument('--val-resizing', type=str, default='default',
                            help='resize mode during validation')
    arg_parser.add_argument('--no-hflip', action='store_true',
                            help='no random horizontal flipping during training')
    arg_parser.add_argument('--color-jitter', action='store_true',
                            help='apply jitter during training')

    # ----- model parameters -----
    arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                            choices=utils.get_model_names(),
                            help='backbone architecture: ' +
                                 ' | '.join(utils.get_model_names()) +
                                 ' (default: resnet50)')
    arg_parser.add_argument('--no-pool', action='store_true',
                            help='no pool layer after the feature extractor. Used in models such as ViT.')
    arg_parser.add_argument('--finetune', action='store_true',
                            help='whether use 10x smaller lr for backbone')
    arg_parser.add_argument('--trade-off', default=2.3, type=float, metavar='P',
                            help='the trade-off hyper-parameter for co-tuning loss')
    arg_parser.add_argument("--relationship", type=str, default='relationship.npy',
                            help="Where to save relationship file.")
    arg_parser.add_argument('--pretrained', default=None,
                            help="pretrained checkpoint of the backbone. "
                                 "(default: None, use the ImageNet supervised pretrained backbone)")

    # ----- training parameters -----
    arg_parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                            help='mini-batch size (default: 48)')
    arg_parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                            metavar='LR', help='initial learning rate', dest='lr')
    arg_parser.add_argument('--lr-gamma', default=0.1, type=float,
                            help='parameter for lr scheduler')
    arg_parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+',
                            help='epochs to decay lr')
    arg_parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                            help='momentum')
    arg_parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                            metavar='W', help='weight decay (default: 5e-4)')
    arg_parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 2)')
    arg_parser.add_argument('--epochs', default=20, type=int, metavar='N',
                            help='number of total epochs to run')
    arg_parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                            help='Number of iterations per epoch')
    arg_parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                            help='print frequency (default: 100)')
    arg_parser.add_argument('--seed', default=None, type=int,
                            help='seed for initializing training. ')
    arg_parser.add_argument("--log", type=str, default='cotuning',
                            help="Where to save logs, checkpoints and debugging images.")
    arg_parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                            help="When phase is 'test', only test the model.")

    main(arg_parser.parse_args())
================================================
FILE: examples/task_adaptation/image_classification/co_tuning.sh
================================================
#!/usr/bin/env bash

# ========== Supervised Pretraining ==========

# CUB-200-2011 at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr ${sr} --seed 0 --finetune --log logs/co_tuning/cub200_${sr}
done

# Stanford Cars at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr ${sr} --seed 0 --finetune --log logs/co_tuning/car_${sr}
done

# Aircrafts at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr ${sr} --seed 0 --finetune --log logs/co_tuning/aircraft_${sr}
done

# CIFAR10
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/co_tuning/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/co_tuning/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/co_tuning/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/co_tuning/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/co_tuning/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/co_tuning/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/co_tuning/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/co_tuning/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars (full training set)
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/co_tuning/stanford_cars/lr_1e-2 --lr 1e-2
# Aircrafts (full training set)
CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/co_tuning/aircraft/lr_1e-2 --lr 1e-2

# ========== MoCo (Unsupervised Pretraining) ==========

# CUB-200-2011 at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_co_tuning/cub200_${sr} --pretrained checkpoints/moco_v1_200ep_backbone.pth
done

# Stanford Cars at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_co_tuning/cars_${sr} --pretrained checkpoints/moco_v1_200ep_backbone.pth
done

# Aircrafts at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_co_tuning/aircraft_${sr} --pretrained checkpoints/moco_v1_200ep_backbone.pth
done
================================================
FILE: examples/task_adaptation/image_classification/convert_moco_to_pretrained.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import sys
import torch
def split_moco_state_dict(state_dict):
    """Split a MoCo checkpoint state dict into backbone and fc weights.

    Keeps only the query-encoder entries (prefix ``module.encoder_q.``),
    strips that prefix, and routes ``fc*`` keys to the head dict and
    everything else to the backbone dict. Key-encoder weights and queue
    buffers are dropped.

    Args:
        state_dict: mapping from parameter name to tensor.

    Returns:
        (backbone, fc): two dicts of re-keyed parameters.
    """
    backbone, fc = {}, {}
    prefix = "module.encoder_q."
    for name, value in state_dict.items():
        if not name.startswith(prefix):
            # key encoder / momentum queue: not needed for finetuning
            continue
        stripped = name[len(prefix):]
        if stripped.startswith("fc"):
            print(stripped)
            fc[stripped] = value
        else:
            backbone[stripped] = value
    return backbone, fc


if __name__ == "__main__":
    # usage: python convert_moco_to_pretrained.py <moco_ckpt> <backbone_out> <fc_out>
    # (renamed from `input` to avoid shadowing the builtin)
    input_path = sys.argv[1]
    checkpoint = torch.load(input_path, map_location="cpu")
    backbone, fc = split_moco_state_dict(checkpoint["state_dict"])
    with open(sys.argv[2], "wb") as f:
        torch.save(backbone, f)
    with open(sys.argv[3], "wb") as f:
        torch.save(fc, f)
================================================
FILE: examples/task_adaptation/image_classification/delta.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import math
import os
import random
import time
import warnings
import sys
import argparse
import shutil
import numpy as np
from tqdm import tqdm
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.regularization.delta import *
from tllib.modules.classifier import Classifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Run DELTA finetuning: build data/model, attach the chosen backbone
    regularizer, train for ``args.epochs`` epochs and keep the best checkpoint.

    In ``--phase test`` the best checkpoint is loaded and evaluated instead.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # deterministic cuDNN trades speed for reproducibility
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    # create model: `classifier` is finetuned on the target task, while
    # `source_classifier` keeps the (frozen) pretrained weights and head and
    # serves as the reference model for the behavioral regularization
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    backbone_source = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)
    source_classifier = Classifier(backbone_source, num_classes=backbone_source.fc.out_features,
                                   head=backbone_source.copy_head(), pool_layer=pool_layer).to(device)
    # the source model is a fixed reference: freeze it and keep it in eval mode
    for param in source_classifier.parameters():
        param.requires_grad = False
    source_classifier.eval()

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # create intermediate layer getter: the last conv of each residual stage
    # is compared between source and target models
    # NOTE(review): resnet101 reuses the resnet50 layer names here
    # (layer3.5, layer4.2) even though its layer3 has more blocks -- confirm
    # these are the intended comparison points for resnet101.
    if args.arch == 'resnet50':
        return_layers = ['backbone.layer1.2.conv3', 'backbone.layer2.3.conv3', 'backbone.layer3.5.conv3',
                         'backbone.layer4.2.conv3']
    elif args.arch == 'resnet101':
        return_layers = ['backbone.layer1.2.conv3', 'backbone.layer2.3.conv3', 'backbone.layer3.5.conv3',
                         'backbone.layer4.2.conv3']
    else:
        raise NotImplementedError(args.arch)
    source_getter = IntermediateLayerGetter(source_classifier, return_layers=return_layers)
    target_getter = IntermediateLayerGetter(classifier, return_layers=return_layers)

    # get regularization on the backbone:
    #   l2_sp                  - penalize weight drift from the pretrained weights
    #   feature_map            - penalize feature-map drift (unweighted)
    #   attention_feature_map  - feature-map drift weighted by channel attention
    if args.regularization_type == 'l2_sp':
        backbone_regularization = SPRegularization(source_classifier.backbone, classifier.backbone)
    elif args.regularization_type == 'feature_map':
        backbone_regularization = BehavioralRegularization()
    elif args.regularization_type == 'attention_feature_map':
        # channel attention is expensive to compute, so it is cached on disk
        attention_file = os.path.join(logger.root, args.attention_file)
        if not os.path.exists(attention_file):
            attention = calculate_channel_attention(train_dataset, return_layers, num_classes, args)
            torch.save(attention, attention_file)
        else:
            print("Loading channel attention from", attention_file)
            attention = torch.load(attention_file)
            attention = [a.to(device) for a in attention]
        backbone_regularization = AttentionBehavioralRegularization(attention)
    else:
        raise NotImplementedError(args.regularization_type)

    # plain L2 regularization on the new (randomly initialized) head layers
    head_regularization = L2Regularization(nn.ModuleList([classifier.head, classifier.bottleneck]))

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # NOTE(review): lr_scheduler.get_lr() is deprecated in newer PyTorch
        # in favor of get_last_lr() -- confirm against the pinned torch version.
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, backbone_regularization, head_regularization, target_getter, source_getter,
              optimizer, epoch, args)
        lr_scheduler.step()

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def calculate_channel_attention(dataset, return_layers, num_classes, args):
    """Compute per-channel attention weights for each layer in `return_layers`.

    A fresh classifier head is trained on `dataset` with the backbone frozen;
    then, for every output channel of each returned layer, the increase in
    loss when that channel's weights are zeroed is measured and averaged over
    a few batches. Per layer, the averaged increases are standardized and
    passed through a temperature-5 softmax to obtain the attention weights
    used by DELTA's attention behavioral regularization.

    Args:
        dataset: training dataset used both for head training and attention
            estimation.
        return_layers: attribute paths of the conv layers to analyze.
        num_classes: number of target classes.
        args: parsed options (arch, lr, attention_* hyper-parameters, ...).

    Returns:
        list of 1-D float tensors on `device`, one per entry of
        `return_layers`, each of length `layer.out_channels`.
    """
    backbone = utils.get_model(args.arch)
    classifier = Classifier(backbone, num_classes).to(device)
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    data_loader = DataLoader(dataset, batch_size=args.attention_batch_size, shuffle=True,
                             num_workers=args.workers, drop_last=False)
    # decay lr by 10x in total over `attention_lr_decay_epochs` epochs
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp(
        math.log(0.1) / args.attention_lr_decay_epochs))
    criterion = nn.CrossEntropyLoss()

    # one running-average slot per output channel of every returned layer
    channel_weights = []
    for layer_id, name in enumerate(return_layers):
        layer = get_attribute(classifier, name)
        channel_weights.append([0] * layer.out_channels)

    # train the classifier head with a frozen backbone.
    # Bug fix: `classifier.backbone.requires_grad = False` only set a plain
    # attribute on the Module and did NOT freeze anything; requires_grad must
    # be set on each parameter.
    classifier.train()
    for param in classifier.backbone.parameters():
        param.requires_grad = False
    print("Pretrain a classifier to calculate channel attention.")
    for epoch in range(args.attention_epochs):
        losses = AverageMeter('Loss', ':3.2f')
        cls_accs = AverageMeter('Cls Acc', ':3.1f')
        progress = ProgressMeter(
            len(data_loader),
            [losses, cls_accs],
            prefix="Epoch: [{}]".format(epoch))
        for i, data in enumerate(data_loader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs, _ = classifier(inputs)  # train mode: (logits, features)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            cls_acc = accuracy(outputs, labels)[0]
            losses.update(loss.item(), inputs.size(0))
            cls_accs.update(cls_acc.item(), inputs.size(0))
            if i % args.print_freq == 0:
                progress.display(i)
        lr_scheduler.step()

    # calculate the channel attention
    print('Calculating channel attention.')
    classifier.eval()
    if args.attention_iteration_limit > 0:
        total_iteration = min(len(data_loader), args.attention_iteration_limit)
    else:
        # bug fix: was `len(args.data_loader)` which raised AttributeError
        # (args has no `data_loader`); a non-positive limit means "no limit"
        total_iteration = len(data_loader)
    progress = ProgressMeter(
        total_iteration,
        [],
        prefix="Iteration: ")
    for i, data in enumerate(data_loader):
        if i >= total_iteration:
            break
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = classifier(inputs)  # eval mode: logits only
        loss_0 = criterion(outputs, labels)
        progress.display(i)
        for layer_id, name in enumerate(tqdm(return_layers)):
            layer = get_attribute(classifier, name)
            for j in range(layer.out_channels):
                # temporarily zero channel j's weights in place (state_dict
                # tensors alias the live parameters), re-evaluate, restore
                tmp = classifier.state_dict()[name + '.weight'][j,].clone()
                classifier.state_dict()[name + '.weight'][j,] = 0.0
                outputs = classifier(inputs)
                loss_1 = criterion(outputs, labels)
                difference = (loss_1 - loss_0).detach().cpu().numpy().item()
                # running mean of the loss increase over the batches seen so far
                history_value = channel_weights[layer_id][j]
                channel_weights[layer_id][j] = 1.0 * (i * history_value + difference) / (i + 1)
                classifier.state_dict()[name + '.weight'][j,] = tmp

    channel_attention = []
    for weight in channel_weights:
        # standardize per layer, then soften with temperature 5; explicit
        # dim=0 (1-D tensor) avoids the implicit-dim softmax deprecation
        weight = np.array(weight)
        weight = (weight - np.mean(weight)) / np.std(weight)
        weight = torch.from_numpy(weight).float().to(device)
        channel_attention.append(F.softmax(weight / 5, dim=0).detach())
    return channel_attention
def train(train_iter: ForeverDataIterator, model: Classifier, backbone_regularization: nn.Module,
          head_regularization: nn.Module,
          target_getter: IntermediateLayerGetter,
          source_getter: IntermediateLayerGetter,
          optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Train `model` for one epoch with the DELTA objective.

    Minimizes cross-entropy on the target task plus a backbone regularizer
    (weighted by ``args.trade_off_backbone``) and an L2 head regularizer
    (weighted by ``args.trade_off_head``). The getters run the frozen source
    model and the target model side by side to expose the intermediate
    feature maps that the feature-map regularizers compare.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f')
    losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_reg_head, losses_reg_backbone, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for step in range(args.iters_per_epoch):
        images, targets = next(train_iter)
        images = images.to(device)
        targets = targets.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # run both models; keep the intermediate activations for regularization
        feats_s, _ = source_getter(images)
        feats_t, head_out = target_getter(images)
        logits, _ = head_out

        # classification loss and accuracy on the target task
        cls_acc = accuracy(logits, targets)[0]
        cls_loss = F.cross_entropy(logits, targets)

        # both feature-map variants compare intermediate activations;
        # l2_sp compares weights directly and takes no inputs
        if args.regularization_type in ('feature_map', 'attention_feature_map'):
            loss_reg_backbone = backbone_regularization(feats_s, feats_t)
        else:
            loss_reg_backbone = backbone_regularization()
        loss_reg_head = head_regularization()
        loss = cls_loss + args.trade_off_backbone * loss_reg_backbone + args.trade_off_head * loss_reg_head

        # record the (weighted) loss terms
        losses_reg_backbone.update(loss_reg_backbone.item() * args.trade_off_backbone, images.size(0))
        losses_reg_head.update(loss_reg_head.item() * args.trade_off_head, images.size(0))
        losses.update(loss.item(), images.size(0))
        cls_accs.update(cls_acc.item(), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Delta for Finetuning')

    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')

    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--regularization-type', choices=['l2_sp', 'feature_map', 'attention_feature_map'],
                        default='attention_feature_map')
    parser.add_argument('--trade-off-backbone', default=0.01, type=float,
                        help='trade-off for backbone regularization')
    parser.add_argument('--trade-off-head', default=0.01, type=float,
                        help='trade-off for head regularization')

    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0., type=float,
                        metavar='W', help='weight decay (default: 0.)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='delta',
                        help="Where to save logs, checkpoints and debugging images.")
    # Bug fix: the help text previously mentioned an 'analysis' phase, but
    # choices only allow 'train' and 'test', so the misleading sentence is removed.
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")

    # parameters for calculating channel attention
    parser.add_argument("--attention-file", type=str, default='channel_attention.pt',
                        help="Where to save and load channel attention file.")
    parser.add_argument('--attention-batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size for calculating channel attention (default: 32)')
    parser.add_argument('--attention-epochs', default=10, type=int, metavar='N',
                        help='number of epochs to train for training before calculating channel weight')
    parser.add_argument('--attention-lr-decay-epochs', default=6, type=int, metavar='N',
                        help='epochs to decay lr for training before calculating channel weight')
    parser.add_argument('--attention-iteration-limit', default=10, type=int, metavar='N',
                        help='iteration limits for calculating channel attention, -1 means no limits')

    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/delta.sh
================================================
#!/usr/bin/env bash

# ========== Supervised ImageNet Pretraining ==========

# CUB-200-2011 at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr ${sr} --seed 0 --finetune --log logs/delta/cub200_${sr}
done

# Stanford Cars at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr ${sr} --seed 0 --finetune --log logs/delta/car_${sr}
done

# Aircrafts at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr ${sr} --seed 0 --finetune --log logs/delta/aircraft_${sr}
done

# CIFAR10
CUDA_VISIBLE_DEVICES=0 python delta.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/delta/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python delta.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/delta/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/delta/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/delta/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python delta.py data/dtd -d DTD --seed 0 --finetune --log logs/delta/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python delta.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/delta/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python delta.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/delta/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python delta.py data/food-101 -d Food101 --seed 0 --finetune --log logs/delta/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars (full training set)
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/delta/stanford_cars/lr_1e-2 --lr 1e-2
# Aircrafts (full training set)
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/delta/aircraft/lr_1e-2 --lr 1e-2

# ========== MoCo (Unsupervised Pretraining) ==========

# CUB-200-2011 at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_delta/cub200_${sr} --pretrained checkpoints/moco_v1_200ep_pretrain.pth
done

# Stanford Cars at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_delta/cars_${sr} --pretrained checkpoints/moco_v1_200ep_pretrain.pth
done

# Aircrafts at several sample rates
for sr in 100 50 30 15; do
    CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr ${sr} --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
        --log logs/moco_pretrain_delta/aircraft_${sr} --pretrained checkpoints/moco_v1_200ep_pretrain.pth
done
================================================
FILE: examples/task_adaptation/image_classification/erm.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.modules.classifier import Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point: finetune a pre-trained backbone on the target dataset (ERM baseline).

    Builds train/val data loaders, creates the classifier, then either
    evaluates the best checkpoint (``--phase test``) or runs the training
    loop, keeping the checkpoint with the highest validation accuracy.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    # drop_last=True keeps every training batch full-sized; epochs are defined
    # by --iters-per-epoch because the iterator below cycles the loader forever.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    # ViT-style backbones need no extra pooling; nn.Identity() disables the
    # classifier's default pool layer when --no-pool is given.
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, epoch, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one training epoch of ``args.iters_per_epoch`` SGD steps.

    Each step draws a batch from the endless iterator, computes the
    cross-entropy loss, updates the model, and records timing/loss/accuracy
    meters, printing progress every ``args.print_freq`` steps.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, targets = next(train_iter)
        images = images.to(device)
        targets = targets.to(device)
        # measure data loading time
        data_time.update(time.time() - tic)

        # forward pass; the classifier also returns features, which ERM ignores
        logits, _ = model(images)
        loss = F.cross_entropy(logits, targets)

        # record loss and top-1 accuracy for this mini-batch
        losses.update(loss.item(), images.size(0))
        top1 = accuracy(logits, targets)[0]
        cls_accs.update(top1.item(), images.size(0))

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Baseline for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    # NOTE(review): '--optimizer' is parsed but never read in main() -- only an
    # SGD optimizer is ever constructed. Confirm before relying on '--optimizer Adam'.
    parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam'])
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/erm.sh
================================================
#!/usr/bin/env bash
# Supervised Pretraining
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/erm/cub200_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/erm/cub200_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/erm/cub200_15
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/erm/car_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/erm/car_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/erm/car_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/erm/car_15
# Aircraft
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/erm/aircraft_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/erm/aircraft_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/erm/aircraft_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/erm/aircraft_15
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/erm/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/erm/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/erm/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/erm/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --seed 0 --finetune --log logs/erm/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/erm/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/erm/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/erm/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/erm/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/erm/aircraft/lr_1e-2 --lr 1e-2
# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Aircraft
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
--log logs/moco_pretrain_erm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
================================================
FILE: examples/task_adaptation/image_classification/lwf.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F
import utils
from tllib.regularization.lwf import collect_pretrain_labels, Classifier
from tllib.regularization.knowledge_distillation import KnowledgeDistillationLoss
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.data import ForeverDataIterator, CombineDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point for LwF (Learning without Forgetting) finetuning.

    Before training, the pretrained source head is run once over the whole
    training set to record its predictions ("pretrain labels"); during
    training a knowledge-distillation loss against those labels preserves the
    source head's behavior while cross-entropy trains the new target head.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    # shuffle=False / drop_last=False so the pretrain labels collected below
    # stay aligned with dataset order when zipped by CombineDataset.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.workers, drop_last=False)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    # head_source keeps a copy of the pretrained classification head alongside
    # the new target head.
    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer,
                            finetune=args.finetune).to(device)
    kd = KnowledgeDistillationLoss(args.T)

    # record the pretrained (source) head's outputs on every training sample,
    # then pair each sample with its recorded labels for the training loader
    source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source)
    pretrain_labels = collect_pretrain_labels(train_loader, source_classifier, device)
    train_dataset = CombineDataset([train_dataset, TensorDataset(pretrain_labels)])
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)

    # define optimizer and lr scheduler
    # NOTE(review): unlike erm.py, no explicit lr= kwarg is passed to SGD;
    # presumably get_parameters(args.lr) sets 'lr' on each param group -- confirm.
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_iter, classifier, kd, optimizer, epoch, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, kd, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one LwF training epoch of ``args.iters_per_epoch`` SGD steps.

    Each step minimizes target cross-entropy plus ``args.trade_off`` times a
    knowledge-distillation loss that keeps the source head close to its
    recorded pretrain labels.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_kd = AverageMeter('Loss (KD)', ':5.4f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_kd, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, target_labels, source_soft_labels = next(train_iter)
        images = images.to(device)
        target_labels = target_labels.to(device)
        source_soft_labels = source_soft_labels.to(device)
        # measure data loading time
        data_time.update(time.time() - tic)

        # forward pass: the model yields source-head and target-head outputs
        source_logits, target_logits = model(images)
        ce_loss = F.cross_entropy(target_logits, target_labels)
        distill_loss = kd(source_logits, source_soft_labels)
        total_loss = ce_loss + args.trade_off * distill_loss

        # record per-batch metrics
        losses.update(ce_loss.item(), images.size(0))
        losses_kd.update(distill_loss.item(), images.size(0))
        top1 = accuracy(target_logits, target_labels)[0]
        cls_accs.update(top1.item(), images.size(0))

        # backward pass and parameter update
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='LWF (Learning without Forgetting) for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    # weight of the knowledge-distillation term in the total loss
    parser.add_argument('--trade-off', default=4, type=float,
                        metavar='P', help='weight of pretrained loss')
    parser.add_argument("-T", type=float, default=3,
                        help="temperature for knowledge distillation")
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='lwf',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/lwf.sh
================================================
#!/usr/bin/env bash
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/lwf/cub200_100 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/lwf/cub200_50 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/lwf/cub200_30 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/lwf/cub200_15 --lr 0.001
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/lwf/car_100 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/lwf/car_50 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/lwf/car_30 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/lwf/car_15 --lr 0.01
# Aircraft
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/lwf/aircraft_100 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/lwf/aircraft_50 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/lwf/aircraft_30 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/lwf/aircraft_15 --lr 0.001
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/lwf/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/lwf/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/lwf/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/lwf/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python lwf.py data/dtd -d DTD --seed 0 --finetune --log logs/lwf/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python lwf.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/lwf/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python lwf.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/lwf/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python lwf.py data/food-101 -d Food101 --seed 0 --finetune --log logs/lwf/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/lwf/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/lwf/aircraft/lr_1e-2 --lr 1e-2
================================================
FILE: examples/task_adaptation/image_classification/requirements.txt
================================================
timm
================================================
FILE: examples/task_adaptation/image_classification/stochnorm.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
import utils
from tllib.normalization.stochnorm import convert_model
from tllib.modules.classifier import Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def main(args: argparse.Namespace):
    """Entry point: finetune with Stochastic Normalization (StochNorm).

    Identical to the ERM baseline except that, after the classifier is built,
    convert_model() replaces its normalization layers with StochNorm layers
    parameterized by --prob.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform,
                                                                val_transform, args.sample_rate,
                                                                args.num_samples_per_classes)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device)
    # swap the classifier's normalization layers for StochNorm layers with
    # selection probability args.prob
    classifier = convert_model(classifier, p=args.prob)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd,
                    nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    for epoch in range(args.epochs):
        # NOTE(review): unlike erm.py, logger.set_epoch(epoch) is not called
        # here -- confirm whether per-epoch logger state is needed.
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, epoch, args)
        lr_scheduler.step()
        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)
        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    logger.close()
def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          epoch: int, args: argparse.Namespace):
    """Run one training epoch of ``args.iters_per_epoch`` mini-batches.

    Progress (batch time, data-loading time, loss, top-1 accuracy) is printed
    every ``args.print_freq`` iterations.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, labels = next(train_iter)
        images, labels = images.to(device), labels.to(device)
        # time spent waiting on the data loader
        data_time.update(time.time() - tic)

        # forward: classifier returns (logits, features); only logits are needed
        logits, _ = model(images)
        loss = F.cross_entropy(logits, labels)

        top1 = accuracy(logits, labels)[0]
        batch_size = images.size(0)
        losses.update(loss.item(), batch_size)
        cls_accs.update(top1.item(), batch_size)

        # backward + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total wall-clock time for this iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='StochNorm for Finetuning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int,
                        metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training')
    parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor. Used in models such as ViT.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    # probability `p` passed to convert_model(): each StochNorm layer stochastically
    # selects between mini-batch and moving statistics with this probability
    parser.add_argument('--prob', '--probability', default=0.5, type=float,
                        metavar='P', help='Probability for StochNorm layers')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int,
                        metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    # nargs='+' lets callers pass multiple milestones, e.g. `--lr-decay-epochs 3 6 9`
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='stochnorm',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/task_adaptation/image_classification/stochnorm.sh
================================================
#!/usr/bin/env bash
# Fine-tuning with StochNorm on standard image-classification benchmarks.
# Supervised ImageNet pretraining (default backbone)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/stochnorm/cub200_100
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/stochnorm/cub200_50
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/stochnorm/cub200_30
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/stochnorm/cub200_15
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/stochnorm/car_100
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/stochnorm/car_50
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/stochnorm/car_30
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/stochnorm/car_15
# Aircraft
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/stochnorm/aircraft_100
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/stochnorm/aircraft_50
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/stochnorm/aircraft_30
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/stochnorm/aircraft_15
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/stochnorm/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/stochnorm/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/stochnorm/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/stochnorm/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/dtd -d DTD --seed 0 --finetune --log logs/stochnorm/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/stochnorm/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/stochnorm/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/stochnorm/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/stochnorm/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/stochnorm/aircraft/lr_1e-2 --lr 1e-2
# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Aircraft
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_stochnorm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
================================================
FILE: examples/task_adaptation/image_classification/utils.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import time
from PIL import Image
import timm
import numpy as np
import random
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Subset
import torchvision.transforms as T
from torch.optim import SGD, Adam
sys.path.append('../../..')
import tllib.vision.datasets as datasets
import tllib.vision.models as models
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.vision.transforms import Denormalize
def get_model_names():
    """Return the names of all supported backbones: tllib.vision.models (sorted) plus every timm model."""
    tllib_names = [
        name for name, obj in models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    ]
    return sorted(tllib_names) + timm.list_models()
def get_model(model_name, pretrained_checkpoint=None):
    """Create a backbone network by name, optionally loading a custom checkpoint.

    Args:
        model_name (str): name of a model in ``tllib.vision.models`` or in timm.
        pretrained_checkpoint (str, optional): path to a state-dict checkpoint for
            the backbone. Default: None (keep the ImageNet-pretrained weights).

    Returns:
        nn.Module: backbone with ``out_features`` and ``copy_head`` attributes set
        and its classification head removed.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models (timm)
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
            backbone.copy_head = backbone.get_classifier
        except Exception:
            # fallback for models exposing the classifier as `.head`
            # (was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit)
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
            # NOTE(review): unlike the try-branch, this lambda expects the backbone
            # as an argument — callers of copy_head may need to account for that
            backbone.copy_head = lambda x: x.head
    if pretrained_checkpoint:
        print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint))
        # map_location='cpu' so the checkpoint loads on machines without a GPU
        # (consistent with how the training scripts load checkpoints)
        pretrained_dict = torch.load(pretrained_checkpoint, map_location='cpu')
        backbone.load_state_dict(pretrained_dict, strict=False)
    return backbone
def get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):
    """
    Build the train/test datasets for ``dataset_name``.

    When sample_rate < 100 (e.g. 50), use that percentage of the data to train.
    Otherwise, if num_samples_per_classes is given (e.g. 5), randomly keep that
    many training images per class; if neither is set, keep all the data.
    """
    dataset_cls = datasets.__dict__[dataset_name]
    if sample_rate < 100:
        train_dataset = dataset_cls(root=root, split='train', sample_rate=sample_rate, download=True,
                                    transform=train_transform)
        test_dataset = dataset_cls(root=root, split='test', sample_rate=100, download=True,
                                   transform=val_transform)
    else:
        train_dataset = dataset_cls(root=root, split='train', download=True, transform=train_transform)
        test_dataset = dataset_cls(root=root, split='test', download=True, transform=val_transform)
    # read num_classes before any Subset wrapping (Subset has no num_classes)
    num_classes = train_dataset.num_classes
    if sample_rate >= 100 and num_samples_per_classes is not None:
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)
        kept = min(num_samples_per_classes * num_classes, len(train_dataset))
        print("Origin dataset:", len(train_dataset), "Sampled dataset:", kept, "Ratio:", float(kept) / len(train_dataset))
        train_dataset = Subset(train_dataset, indices[:kept])
    return train_dataset, test_dataset, num_classes
def validate(val_loader, model, args, device, visualize=None) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        tic = time.time()
        for step, (images, target) in enumerate(val_loader):
            images, target = images.to(device), target.to(device)

            # forward pass + loss
            output = model(images)
            loss = F.cross_entropy(output, target)

            # accuracy / loss bookkeeping
            acc1, = accuracy(output, target, topk=(1,))
            batch_size = images.size(0)
            losses.update(loss.item(), batch_size)
            top1.update(acc1.item(), batch_size)

            # wall-clock time per batch
            batch_time.update(time.time() - tic)
            tic = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(images[0], "val_{}".format(step))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return top1.avg
def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False):
    """Build the training-time augmentation pipeline.

    resizing mode:
        - default: take a random resized crop of size 224 with scale in [0.2, 1.];
        - res.: resize the image to 224;
        - res.|crop: resize the image to 256 and take a random crop of size 224;
        - res.sma|crop: resize the image keeping its aspect ratio such that the
          smaller side is 256, then take a random crop of size 224;
        - inc.crop: "inception crop" from (Szegedy et al., 2015);
        - cif.crop: resize the image to 224, zero-pad it by 28 on each side, then
          take a random crop of size 224.
    """
    # lazily-constructed resize step per mode
    resize_builders = {
        'default': lambda: T.RandomResizedCrop(224, scale=(0.2, 1.)),
        'res.': lambda: T.Resize((224, 224)),
        'res.|crop': lambda: T.Compose([T.Resize((256, 256)), T.RandomCrop(224)]),
        'res.sma|crop': lambda: T.Compose([T.Resize(256), T.RandomCrop(224)]),
        'inc.crop': lambda: T.RandomResizedCrop(224),
        'cif.crop': lambda: T.Compose([T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224)]),
    }
    if resizing not in resize_builders:
        raise NotImplementedError(resizing)
    steps = [resize_builders[resizing]()]
    if random_horizontal_flip:
        steps.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        steps.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    # ImageNet normalization
    steps.append(T.ToTensor())
    steps.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    return T.Compose(steps)
def get_val_transform(resizing='default'):
    """Build the evaluation-time preprocessing pipeline.

    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to 224;
        - res.|crop: resize the image to (256, 256) and then take a central crop
          of size 224.
    """
    if resizing == 'default':
        resize = T.Compose([T.Resize(256), T.CenterCrop(224)])
    elif resizing == 'res.':
        resize = T.Resize((224, 224))
    elif resizing == 'res.|crop':
        resize = T.Compose([T.Resize((256, 256)), T.CenterCrop(224)])
    else:
        raise NotImplementedError(resizing)
    # resize step followed by ImageNet normalization
    return T.Compose([
        resize,
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
def get_optimizer(optimizer_name, params, lr, wd, momentum):
    """Create an optimizer by name.

    Args:
        optimizer_name: one of 'SGD' (with Nesterov momentum) or 'Adam'
        params: iterable of parameters to optimize or dicts defining parameter groups
        lr: learning rate
        wd: weight decay
        momentum: momentum factor (only used by SGD)

    Raises:
        NotImplementedError: if ``optimizer_name`` is not recognized.
    """
    if optimizer_name == 'SGD':
        return SGD(params=params, lr=lr, momentum=momentum, weight_decay=wd, nesterov=True)
    if optimizer_name == 'Adam':
        return Adam(params=params, lr=lr, weight_decay=wd)
    raise NotImplementedError(optimizer_name)
def visualize(image, filename):
    """Undo ImageNet normalization on a (3, H, W) tensor and save it as an image file.

    Args:
        image (tensor): 3 x H x W, normalized with ImageNet statistics
        filename: destination path of the saved image
    """
    denormalize = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    restored = denormalize(image.detach().cpu())
    # CHW -> HWC and scale from [0, 1] to [0, 255]
    pixels = restored.numpy().transpose((1, 2, 0)) * 255
    Image.fromarray(np.uint8(pixels)).save(filename)
================================================
FILE: requirements.txt
================================================
torch>=1.7.0
torchvision>=0.5.0
numpy
prettytable
tqdm
scikit-learn
webcolors
matplotlib
opencv-python
numba
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages
import re
from os import path

here = path.abspath(path.dirname(__file__))

# Get the version string from tllib/__init__.py (single source of truth).
with open(path.join(here, 'tllib', '__init__.py')) as f:
    version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1)

# Get all runtime requirements, stripping inline comments.
REQUIRES = []
with open('requirements.txt') as f:
    for line in f:
        line, _, _ = line.partition('#')
        line = line.strip()
        if line:  # bug fix: blank / comment-only lines used to be appended as ''
            REQUIRES.append(line)

if __name__ == '__main__':
    setup(
        name="tllib",  # Replace with your own username
        version=version,
        author="THUML",
        author_email="JiangJunguang1123@outlook.com",
        keywords="domain adaptation, task adaptation, domain generalization, "
                 "transfer learning, deep learning, pytorch",
        description="A Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization",
        long_description=open('README.md', encoding='utf8').read(),
        long_description_content_type="text/markdown",
        url="https://github.com/thuml/Transfer-Learning-Library",
        packages=find_packages(exclude=['docs', 'examples']),
        classifiers=[
            # How mature is this project? Common values are
            #   3 - Alpha
            #   4 - Beta
            #   5 - Production/Stable
            'Development Status :: 3 - Alpha',
            # Indicate who your project is intended for
            'Intended Audience :: Science/Research',
            'Topic :: Scientific/Engineering :: Artificial Intelligence',
            'Topic :: Software Development :: Libraries :: Python Modules',
            # Pick your license as you wish (should match "license" above)
            'License :: OSI Approved :: MIT License',
            # Specify the Python versions you support here. In particular, ensure
            # that you indicate whether you support Python 2, Python 3 or both.
            'Programming Language :: Python :: 3.6',
            'Programming Language :: Python :: 3.7',
            'Programming Language :: Python :: 3.8',
        ],
        python_requires='>=3.6',
        install_requires=REQUIRES,
        extras_require={
            'dev': [
                'Sphinx',
                'sphinx_rtd_theme',
            ]
        },
    )
================================================
FILE: tllib/__init__.py
================================================
from . import alignment
from . import self_training
from . import translation
from . import regularization
from . import utils
from . import vision
from . import modules
from . import ranking
__version__ = '0.4'
__all__ = ['alignment', 'self_training', 'translation', 'regularization', 'utils', 'vision', 'modules', 'ranking']
================================================
FILE: tllib/alignment/__init__.py
================================================
from . import cdan
from . import dann
from . import mdd
from . import dan
from . import jan
from . import mcd
from . import osbp
from . import adda
from . import bsp
================================================
FILE: tllib/alignment/adda.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
class DomainAdversarialLoss(nn.Module):
    r"""Domain adversarial loss from `Adversarial Discriminative Domain Adaptation
    (CVPR 2017) <https://arxiv.org/abs/1702.05464>`_.

    As in the original GAN paper, ADDA replaces :math:`\text{log}(1-p)` with
    :math:`-\text{log}(p)` in the adversarial loss for better gradients: both the
    source and the target objective are binary cross entropy against a constant
    domain label (source = 1, target = 0).

    Inputs:
        - domain_pred (tensor): predictions of domain discriminator
        - domain_label (str, optional): whether the data comes from source or target.
          Must be 'source' or 'target'. Default: 'source'

    Shape:
        - domain_pred: :math:`(minibatch,)`.
        - Outputs: scalar.
    """

    def __init__(self):
        super(DomainAdversarialLoss, self).__init__()

    def forward(self, domain_pred, domain_label='source'):
        assert domain_label in ['source', 'target']
        # constant target: 1 for source-domain samples, 0 for target-domain samples
        target_value = 1.0 if domain_label == 'source' else 0.0
        target = torch.full_like(domain_pred, target_value)
        return F.binary_cross_entropy(domain_pred, target)
class ImageClassifier(ClassifierBase):
    """Image classifier for ADDA: backbone followed by a (Linear, BatchNorm1d, ReLU) bottleneck."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def freeze_bn(self):
        # Put every batch-norm layer into eval mode so its running statistics stay fixed.
        for module in self.modules():
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                module.eval()

    def get_parameters(self, base_lr=1.0, optimize_head=True) -> List[Dict]:
        # Backbone gets a 10x smaller lr when fine-tuning from pretrained weights.
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        param_groups = [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
        ]
        if optimize_head:
            param_groups.append({"params": self.head.parameters(), "lr": 1.0 * base_lr})
        return param_groups
================================================
FILE: tllib/alignment/advent.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np
class Discriminator(nn.Sequential):
    """
    Domain discriminator model from
    `ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic
    Segmentation (CVPR 2019) <https://arxiv.org/abs/1811.12833>`_

    Distinguish pixel-by-pixel whether the input predictions come from the source
    domain or the target domain. The source domain label is 1 and the target
    domain label is 0.

    Args:
        num_classes (int): num of classes in the predictions
        ndf (int): dimension of the hidden features

    Shape:
        - Inputs: :math:`(minibatch, C, H, W)` where :math:`C` is the number of classes
        - Outputs: :math:`(minibatch, 1, H, W)`
    """

    def __init__(self, num_classes, ndf=64):
        # four stride-2 conv + LeakyReLU stages doubling the channels each time,
        # followed by a final stride-2 conv down to a single logit map
        widths = [num_classes, ndf, ndf * 2, ndf * 4, ndf * 8]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        layers.append(nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1))
        super(Discriminator, self).__init__(*layers)
def prob_2_entropy(prob):
    """Convert probabilistic prediction maps to weighted self-information maps.

    Args:
        prob (tensor): probability map of shape :math:`(N, C, H, W)`.

    Returns:
        Tensor of the same shape containing :math:`-p \\log_2(p) / \\log_2(C)`,
        so that the per-pixel entropy (sum over the class dimension) is in [0, 1].
    """
    num_classes = prob.size(1)  # only C is needed; n/h/w were unused
    # 1e-30 guards against log2(0); np.log2(C) normalizes the entropy range
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(num_classes)
def bce_loss(y_pred, y_label):
    """Binary cross entropy (with logits) against a constant domain label.

    Args:
        y_pred (tensor): discriminator logits of any shape.
        y_label (float): constant target value (1 for source, 0 for target).

    Returns:
        Scalar loss tensor.
    """
    # full_like builds the target on y_pred's device and dtype directly.
    # The previous implementation used y_pred.get_device(), which breaks
    # for CPU tensors (no valid device index).
    y_truth_tensor = torch.full_like(y_pred, y_label)
    return F.binary_cross_entropy_with_logits(y_pred, y_truth_tensor)
class DomainAdversarialEntropyLoss(nn.Module):
    r"""The `Domain Adversarial Entropy Loss <https://arxiv.org/abs/1811.12833>`_

    Minimizing entropy with adversarial learning through training a domain
    discriminator on weighted self-information maps.

    Args:
        domain_discriminator (torch.nn.Module): A domain discriminator object, which
            predicts the domains of predictions. Its input shape is
            :math:`(minibatch, C, H, W)` and output shape is :math:`(minibatch, 1, H, W)`

    Inputs:
        - logits (tensor): logits output of segmentation model
        - domain_label (str, optional): whether the data comes from source or target.
          Choices: ['source', 'target']. Default: 'source'

    Shape:
        - logits: :math:`(minibatch, C, H, W)` where :math:`C` means the number of classes
        - Outputs: scalar.

    Examples::
        >>> B, C, H, W = 2, 19, 512, 512
        >>> discriminator = Discriminator(num_classes=C)
        >>> dann = DomainAdversarialEntropyLoss(discriminator)
        >>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
        >>> loss = 0.5 * (dann(y_s, "source") + dann(y_t, "target"))
    """

    def __init__(self, discriminator: nn.Module):
        super(DomainAdversarialEntropyLoss, self).__init__()
        self.discriminator = discriminator

    def forward(self, logits, domain_label='source'):
        """Compute the adversarial loss for one domain's segmentation logits."""
        assert domain_label in ['source', 'target']
        # discriminate on the weighted self-information map of the softmax output
        weighted_info = prob_2_entropy(F.softmax(logits, dim=1))
        domain_pred = self.discriminator(weighted_info)
        # source samples are labeled 1, target samples 0
        constant_label = 1 if domain_label == 'source' else 0
        return bce_loss(domain_pred, constant_label)

    def train(self, mode=True):
        r"""Set the discriminator to training (or evaluation) mode and toggle
        ``requires_grad`` on all of its parameters accordingly.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Default: ``True``.
        """
        self.discriminator.train(mode)
        for param in self.discriminator.parameters():
            param.requires_grad = mode
        return self

    def eval(self):
        r"""Set the module in evaluation mode; all discriminator parameters get
        requires_grad=False. Equivalent to :meth:`self.train(False) <train>`.
        """
        return self.train(False)
================================================
FILE: tllib/alignment/bsp.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
from tllib.modules.classifier import Classifier as ClassifierBase
class BatchSpectralPenalizationLoss(nn.Module):
    r"""Batch spectral penalization loss from `Transferability vs. Discriminability:
    Batch Spectral Penalization for Adversarial Domain Adaptation (ICML 2019)
    <http://proceedings.mlr.press/v97/chen19i.html>`_.

    For source features :math:`f_s` and target features :math:`f_t` in the current
    mini-batch, a singular value decomposition of each is computed and the loss is

    .. math::
        loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2)

    where :math:`\sigma_{s,i},\sigma_{t,i}` are the :math:`i`-th largest singular
    values of the source/target features. Empirically :math:`k=1`.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.
        - Outputs: scalar.
    """

    def __init__(self):
        super(BatchSpectralPenalizationLoss, self).__init__()

    def forward(self, f_s, f_t):
        # penalize the squared largest singular value of each feature matrix
        penalty = 0.
        for features in (f_s, f_t):
            _, singular_values, _ = torch.svd(features)
            penalty = penalty + singular_values[0] ** 2
        return penalty
class ImageClassifier(ClassifierBase):
    """Image classifier for BSP: backbone followed by a (Linear, BatchNorm1d, ReLU) bottleneck."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        projection = nn.Linear(backbone.out_features, bottleneck_dim)
        bottleneck = nn.Sequential(
            projection,
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
================================================
FILE: tllib/alignment/cdan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.utils.metric import binary_accuracy, accuracy
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.modules.entropy import entropy
__all__ = ['ConditionalDomainAdversarialLoss', 'ImageClassifier']
class ConditionalDomainAdversarialLoss(nn.Module):
r"""The Conditional Domain Adversarial Loss used in `Conditional Adversarial Domain Adaptation (NIPS 2018) `_
Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a
conditional manner. Given domain discriminator :math:`D`, feature representation :math:`f` and
classifier predictions :math:`g`, the definition of CDAN loss is
.. math::
loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\
&+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\
where :math:`T` is a :class:`MultiLinearMap` or :class:`RandomizedMultiLinearMap` which convert two tensors to a single tensor.
Args:
domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of
features. Its input shape is (N, F) and output shape is (N, 1)
entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example.
Default: False
randomized (bool, optional): If True, use `randomized multi linear map`. Else, use `multi linear map`.
Default: False
num_classes (int, optional): Number of classes. Default: -1
features_dim (int, optional): Dimension of input features. Default: -1
randomized_dim (int, optional): Dimension of features after randomized. Default: 1024
reduction (str, optional): Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
``'mean'``: the sum of the output will be divided by the number of
elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
.. note::
You need to provide `num_classes`, `features_dim` and `randomized_dim` **only when** `randomized`
is set True.
Inputs:
- g_s (tensor): unnormalized classifier predictions on source domain, :math:`g^s`
- f_s (tensor): feature representations on source domain, :math:`f^s`
- g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`
- f_t (tensor): feature representations on target domain, :math:`f^t`
Shape:
- g_s, g_t: :math:`(minibatch, C)` where C means the number of classes.
- f_s, f_t: :math:`(minibatch, F)` where F means the dimension of input features.
- Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, )`.
Examples::
>>> from tllib.modules.domain_discriminator import DomainDiscriminator
>>> from tllib.alignment.cdan import ConditionalDomainAdversarialLoss
>>> import torch
>>> num_classes = 2
>>> feature_dim = 1024
>>> batch_size = 10
>>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
>>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
>>> # features from source domain and target domain
>>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
>>> # logits output from source domain adn target domain
>>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
>>> output = loss(g_s, f_s, g_t, f_t)
"""
def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False,
randomized: Optional[bool] = False, num_classes: Optional[int] = -1,
features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024,
reduction: Optional[str] = 'mean', sigmoid=True):
super(ConditionalDomainAdversarialLoss, self).__init__()
self.domain_discriminator = domain_discriminator
self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)
self.entropy_conditioning = entropy_conditioning
self.sigmoid = sigmoid
self.reduction = reduction
if randomized:
assert num_classes > 0 and features_dim > 0 and randomized_dim > 0
self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)
else:
self.map = MultiLinearMap()
self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight,
reduction=reduction) if self.entropy_conditioning \
else F.binary_cross_entropy(input, target, reduction=reduction)
self.domain_discriminator_accuracy = None
def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
f = torch.cat((f_s, f_t), dim=0)
g = torch.cat((g_s, g_t), dim=0)
g = F.softmax(g, dim=1).detach()
h = self.grl(self.map(f, g))
d = self.domain_discriminator(h)
weight = 1.0 + torch.exp(-entropy(g))
batch_size = f.size(0)
weight = weight / torch.sum(weight) * batch_size
if self.sigmoid:
d_label = torch.cat((
torch.ones((g_s.size(0), 1)).to(g_s.device),
torch.zeros((g_t.size(0), 1)).to(g_t.device),
))
self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
if self.entropy_conditioning:
return F.binary_cross_entropy(d, d_label, weight.view_as(d), reduction=self.reduction)
else:
return F.binary_cross_entropy(d, d_label, reduction=self.reduction)
else:
d_label = torch.cat((
torch.ones((g_s.size(0), )).to(g_s.device),
torch.zeros((g_t.size(0), )).to(g_t.device),
)).long()
self.domain_discriminator_accuracy = accuracy(d, d_label)
if self.entropy_conditioning:
raise NotImplementedError("entropy_conditioning")
return F.cross_entropy(d, d_label, reduction=self.reduction)
class RandomizedMultiLinearMap(nn.Module):
    """Random multi linear map

    Given two inputs :math:`f` and :math:`g`, the definition is

    .. math::
        T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),

    where :math:`\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices
    sampled only once and fixed in training.

    Args:
        features_dim (int): dimension of input :math:`f`
        num_classes (int): dimension of input :math:`g`
        output_dim (int, optional): dimension of output tensor. Default: 1024

    Shape:
        - f: (minibatch, features_dim)
        - g: (minibatch, num_classes)
        - Outputs: (minibatch, output_dim)
    """

    def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):
        super(RandomizedMultiLinearMap, self).__init__()
        # Register the fixed random projections as non-persistent buffers so they
        # follow `module.to(device)` / `.cuda()` instead of being copied to the
        # input's device on every forward call. `persistent=False` keeps them out
        # of the state_dict, preserving checkpoint compatibility with the previous
        # plain-attribute storage.
        self.register_buffer('Rf', torch.randn(features_dim, output_dim), persistent=False)
        self.register_buffer('Rg', torch.randn(num_classes, output_dim), persistent=False)
        self.output_dim = output_dim

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # `.to(...)` is a no-op when the module was moved with the model, but keeps
        # the old behavior working if it was not.
        f = torch.mm(f, self.Rf.to(f.device))
        g = torch.mm(g, self.Rg.to(g.device))
        # Scale by 1/sqrt(d) so the output magnitude is independent of output_dim.
        output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
        return output
class MultiLinearMap(nn.Module):
    """Exact multi linear map: the per-sample outer product of :math:`g` and
    :math:`f`, flattened into a single vector.

    Shape:
        - f: (minibatch, F)
        - g: (minibatch, C)
        - Outputs: (minibatch, F * C)
    """

    def __init__(self):
        super(MultiLinearMap, self).__init__()

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        # outer[b, c, j] = g[b, c] * f[b, j]; flattening row-major gives the same
        # layout as bmm(g.unsqueeze(2), f.unsqueeze(1)).view(batch, -1).
        outer = torch.einsum('bc,bf->bcf', g, f)
        return outer.reshape(f.size(0), -1)
class ImageClassifier(ClassifierBase):
    """Image classifier for CDAN: backbone features are projected through a
    Linear + BatchNorm + ReLU bottleneck before the classification head."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        neck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, neck, bottleneck_dim, **kwargs)
================================================
FILE: tllib/alignment/coral.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn
class CorrelationAlignmentLoss(nn.Module):
    r"""The `Correlation Alignment Loss` in
    `Deep CORAL: Correlation Alignment for Deep Domain Adaptation (ECCV 2016) <https://arxiv.org/abs/1607.01719>`_.

    Given source features :math:`f_S` and target features :math:`f_T`, this
    implementation penalizes both the squared difference of the feature means and
    the squared difference of the (unbiased) feature covariances, each averaged
    over all entries.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s, f_t: :math:`(N, d)` where d means the dimension of input features, :math:`N=n_S=n_T` is mini-batch size.
        - Outputs: scalar.
    """

    def __init__(self):
        super(CorrelationAlignmentLoss, self).__init__()

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        mu_s = f_s.mean(dim=0, keepdim=True)
        mu_t = f_t.mean(dim=0, keepdim=True)
        # Center the features, then form unbiased covariance estimates.
        delta_s = f_s - mu_s
        delta_t = f_t - mu_t
        cov_s = delta_s.t() @ delta_s / (f_s.size(0) - 1)
        cov_t = delta_t.t() @ delta_t / (f_t.size(0) - 1)
        # Mean-squared gap of first moments plus mean-squared gap of covariances.
        return (mu_s - mu_t).pow(2).mean() + (cov_s - cov_t).pow(2).mean()
================================================
FILE: tllib/alignment/d_adapt/__init__.py
================================================
================================================
FILE: tllib/alignment/d_adapt/feedback.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import itertools
import numpy as np
import copy
import logging
from typing import List, Optional, Union
import torch
from detectron2.config import configurable
from detectron2.structures import BoxMode, Boxes, Instances
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.data.build import filter_images_with_only_crowd_annotations, filter_images_with_few_keypoints, \
print_instances_class_histogram
from detectron2.data.detection_utils import check_metadata_consistency
import detectron2.data.transforms as T
import detectron2.data.detection_utils as utils
from .proposal import Proposal
def load_feedbacks_into_dataset(dataset_dicts, proposals_list: List[Proposal]):
    """
    Load precomputed object feedbacks into the dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposals_list (list[Proposal]): list of Proposal.

    Returns:
        list[dict]: the same format as dataset_dicts, but added feedback field.
        Records are modified in place; images that received no valid feedback
        (no predicted class >= 0) are dropped from the returned list.
    """
    # Bucket predictions by image id; multiple Proposal objects for the same
    # image accumulate into the same bucket.
    per_image = {str(record["image_id"]): {'pred_boxes': [], 'pred_classes': []}
                 for record in dataset_dicts}
    for proposals in proposals_list:
        bucket = per_image[str(proposals.image_id)]
        bucket['pred_boxes'] += proposals.pred_boxes.tolist()
        bucket['pred_classes'] += proposals.pred_classes.tolist()

    # Assuming default bbox_mode of precomputed feedbacks are 'XYXY_ABS'
    bbox_mode = BoxMode.XYXY_ABS

    kept = []
    for record in dataset_dicts:
        image_id = str(record["image_id"])
        pred_boxes = per_image[image_id]["pred_boxes"]
        pred_classes = per_image[image_id]["pred_classes"]
        # The same box list serves as both proposal boxes and pseudo ground-truth.
        record["feedback_proposal_boxes"] = pred_boxes
        record["feedback_gt_classes"] = pred_classes
        record["feedback_gt_boxes"] = pred_boxes
        record["feedback_bbox_mode"] = bbox_mode
        # remove images without feedbacks
        if any(c >= 0 for c in pred_classes):
            kept.append(record)
    return kept
def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposals_list=None):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposals_list (optional, list[Proposal]): list of Proposal.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names

    per_dataset = [DatasetCatalog.get(name) for name in names]
    for name, dicts in zip(names, per_dataset):
        assert len(dicts), "Dataset '{}' is empty!".format(name)
    dataset_dicts = list(itertools.chain.from_iterable(per_dataset))

    if proposals_list is not None:
        # load precomputed feedbacks for each proposals
        dataset_dicts = load_feedbacks_into_dataset(dataset_dicts, proposals_list)

    has_instances = "annotations" in dataset_dicts[0]
    if has_instances and filter_empty:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if has_instances and min_keypoints > 0:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:
            # class names are not available for this dataset; skip the histogram
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts
def transform_feedbacks(dataset_dict, image_shape, transforms, *, min_box_size=0):
    """
    Apply transformations to the feedbacks in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly contains
            fields "feedback_proposal_boxes", "feedback_gt_boxes",
            "feedback_gt_classes", "feedback_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        min_box_size (int): proposals with either side smaller than this
            threshold are removed

    The input dict is modified in-place, with the feedback_* keys popped.
    A new key "feedbacks" is added: an `Instances` object holding the
    transformed "proposal_boxes", "gt_boxes" and "gt_classes".
    """
    if "feedback_proposal_boxes" not in dataset_dict:
        return

    bbox_mode = dataset_dict.get("feedback_bbox_mode")

    def _pop_transformed_boxes(key):
        # Convert to absolute XYXY coordinates and apply the same geometric
        # transforms that were applied to the image.
        raw = BoxMode.convert(dataset_dict.pop(key), bbox_mode, BoxMode.XYXY_ABS)
        return Boxes(transforms.apply_box(raw))

    proposal_boxes = _pop_transformed_boxes("feedback_proposal_boxes")
    gt_boxes = _pop_transformed_boxes("feedback_gt_boxes")
    gt_classes = torch.as_tensor(dataset_dict.pop("feedback_gt_classes"))

    proposal_boxes.clip(image_shape)
    gt_boxes.clip(image_shape)

    # Keep proposals that are non-degenerate AND carry a valid class feedback.
    keep = proposal_boxes.nonempty(threshold=min_box_size) & (gt_classes >= 0)

    feedbacks = Instances(image_shape)
    feedbacks.proposal_boxes = proposal_boxes[keep]
    feedbacks.gt_boxes = gt_boxes[keep]
    feedbacks.gt_classes = gt_classes[keep]
    dataset_dict["feedbacks"] = feedbacks
class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`

    Compared with the upstream detectron2 mapper, this variant additionally
    applies the geometric transforms to adaptor feedbacks via
    :func:`transform_feedbacks`.
    """
    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        # Recomputing boxes derives them from masks, so masks must be loaded.
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes
        # fmt: on
        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")
    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        """Build the mapper's constructor arguments from a detectron2 config."""
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            # Cropping can cut objects, so boxes must be recomputed from masks
            # (hence tied to MASK_ON below).
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False
        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }
        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)
        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret
    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)
        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None
        # Apply the augmentation pipeline once; the same `transforms` are then
        # reused for proposals, feedbacks and annotations so everything stays
        # geometrically consistent with the image.
        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg
        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))
        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )
        # Unlike the proposals above, adaptor feedbacks are transformed
        # unconditionally; transform_feedbacks is a no-op when the dict carries
        # no "feedback_proposal_boxes" key.
        transform_feedbacks(
            dataset_dict, image_shape, transforms
        )
        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict
        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)
            # USER: Implement additional transformations if you have other types of data
            # Crowd annotations (iscrowd == 1) are dropped here.
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )
            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
================================================
FILE: tllib/alignment/d_adapt/modeling/__init__.py
================================================
from . import meta_arch
from . import roi_heads
================================================
FILE: tllib/alignment/d_adapt/modeling/matcher.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
from torch import Tensor, nn
from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple
class MaxOverlapMatcher(object):
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to one predicted elements.
    """

    def __init__(self):
        pass

    def __call__(self, match_quality_matrix):
        """
        Args:
            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
                pairwise quality between M ground-truth elements and N predicted
                elements. All elements must be >= 0 (due to the use of `torch.nonzero`
                for selecting indices in :meth:`set_low_quality_matches_`).

        Returns:
            matches (Tensor[int64]): a vector of length N, where matches[i] is a matched
                ground-truth index in [0, M)
            match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates
                whether a prediction is a true or false positive or ignored
        """
        assert match_quality_matrix.dim() == 2
        # match_quality_matrix is M (gt) x N (predicted):
        # for every prediction, pick the ground-truth with the highest quality.
        matched_gt_indices = match_quality_matrix.max(dim=0).indices
        num_predictions = match_quality_matrix.size(1)
        # Start with every prediction marked "ignored" (-1).
        prediction_labels = match_quality_matrix.new_full(
            (num_predictions,), -1, dtype=torch.int8
        )
        # For each ground-truth, find its best quality over all predictions,
        # then mark every prediction achieving that quality (ties included) as
        # positive — even when the best available quality is low.
        # (.nonzero(as_tuple=True) is the eager-mode equivalent of detectron2's
        # nonzero_tuple helper; qualities must be positive for this to work.)
        best_quality_per_gt = match_quality_matrix.max(dim=1).values
        _, best_prediction_indices = (
            match_quality_matrix == best_quality_per_gt[:, None]
        ).nonzero(as_tuple=True)
        prediction_labels[best_prediction_indices] = 1
        return matched_gt_indices, prediction_labels
================================================
FILE: tllib/alignment/d_adapt/modeling/meta_arch/__init__.py
================================================
from .rcnn import DecoupledGeneralizedRCNN
from .retinanet import DecoupledRetinaNet
================================================
FILE: tllib/alignment/d_adapt/modeling/meta_arch/rcnn.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
from typing import Optional, Callable, Tuple, Any, List, Sequence, Dict
import numpy as np
from detectron2.utils.events import get_event_storage
from detectron2.structures import Instances
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.modeling.postprocessing import detector_postprocess
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from tllib.vision.models.object_detection.meta_arch import TLGeneralizedRCNN
@META_ARCH_REGISTRY.register()
class DecoupledGeneralizedRCNN(TLGeneralizedRCNN):
    """
    Generalized R-CNN for Decoupled Adaptation (D-adapt).

    Similar to that in Supervised Learning, DecoupledGeneralizedRCNN has the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    Different from that in Supervised Learning, DecoupledGeneralizedRCNN

    1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training
    2. generates foreground and background proposals during inference

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        proposal_generator: a module that generates proposals using backbone features
        roi_heads: a ROI head that performs per-region computation
        pixel_mean, pixel_std: list or tuple with #channels element,
            representing the per-channel mean and std to be used to normalize
            the input image
        input_format: describe the meaning of channels of input. Needed by visualization
        vis_period: the period to run visualization. Set to 0 to disable.
        finetune (bool): whether finetune the detector or train from scratch. Default: True

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item in the list contains the inputs for one image.
          For now, each item in the list is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * feedbacks (optional): :class:`Instances`, feedbacks from adaptors.
          * "height", "width" (int): the output resolution of the model, used in inference.
            See :meth:`postprocess` for details.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - outputs (during inference): A list of dict where each dict is the output for one input image.
          The dict contains a key "instances" whose value is a :class:`Instances`.
          The :class:`Instances` object has the following keys:
          "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        - losses (during training): A dict of different losses
    """
    def __init__(self, *args, **kwargs):
        # No extra state; kept so subclass construction mirrors the base class.
        super().__init__(*args, **kwargs)
    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):
        """Compute training losses; in eval mode, delegate to :meth:`inference`.

        `labeled` is forwarded to the proposal generator and ROI heads, which
        decide whether to supervise with ground-truth instances or with the
        adaptor feedbacks carried in `batched_inputs`.
        """
        if not self.training:
            return self.inference(batched_inputs)
        images = self.preprocess_image(batched_inputs)
        # Ground-truth instances (when present) are moved to the model's device.
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None
        features = self.backbone(images.tensor)
        # Adaptor feedbacks (when present) act as pseudo supervision.
        if "feedbacks" in batched_inputs[0]:
            feedbacks = [x["feedbacks"].to(self.device) for x in batched_inputs]
        else:
            feedbacks = None
        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, labeled=labeled)
        _, _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, feedbacks, labeled=labeled)
        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        # Periodic training visualization (every vis_period iterations).
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals, feedbacks)
        return losses
    def visualize_training(self, batched_inputs, proposals, feedbacks=None):
        """
        A function used to visualize images and proposals. It shows ground truth
        bounding boxes on the original image and up to 20 top-scoring predicted
        object proposals on the original image. Users can implement different
        visualization functions for different models.

        Args:
            batched_inputs (list): a list that contains input to the model.
            proposals (list): a list that contains predicted proposals. Both
                batched_inputs and proposals should have the same length.
            feedbacks (list): a list that contains feedbacks from adaptors. Both
                batched_inputs and feedbacks should have the same length.
        """
        from detectron2.utils.visualizer import Visualizer
        storage = get_event_storage()
        max_vis_prop = 20
        for input, prop in zip(batched_inputs, proposals):
            img = input["image"]
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            num_classes = self.roi_heads.box_predictor.num_classes
            if feedbacks is not None:
                # feedbacks[0] matches the first image, which is the only one
                # visualized (the loop breaks below).
                v_feedback_gt = Visualizer(img, None)
                instance = feedbacks[0].to(torch.device("cpu"))
                # gt_classes == num_classes marks background; split fg/bg panels.
                v_feedback_gt = v_feedback_gt.overlay_instances(
                    boxes=instance.proposal_boxes[instance.gt_classes != num_classes])
                feedback_gt_img = v_feedback_gt.get_image()
                v_feedback_gf = Visualizer(img, None)
                v_feedback_gf = v_feedback_gf.overlay_instances(
                    boxes=instance.proposal_boxes[instance.gt_classes == num_classes])
                feedback_gf_img = v_feedback_gf.get_image()
                vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img))
                vis_img = vis_img.transpose(2, 0, 1)
                vis_name = f"Top: GT; Middle: Pred; Bottom: Feedback GT, Feedback GF"
            else:
                vis_img = np.concatenate((anno_img, prop_img), axis=1)
                vis_img = vis_img.transpose(2, 0, 1)
                vis_name = "Left: GT bounding boxes; Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch
    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image.
                The inference will then skip the detection of bounding boxes,
                and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.

        NOTE(review): `detected_instances` and `do_postprocess` are accepted for
        interface compatibility with the base class but are not used in this
        body — postprocessing always runs and detection is never skipped.
        """
        assert not self.training
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        # roi_heads returns foreground results and background proposals separately.
        results, background_results, _ = self.roi_heads(images, features, proposals, None)
        processed_results = []
        for results_per_image, background_results_per_image, input_per_image, image_size in zip(
                results, background_results, batched_inputs, images.image_sizes
        ):
            # Rescale predictions to the output resolution requested by the caller
            # (falling back to the network input size).
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            background_r = detector_postprocess(background_results_per_image, height, width)
            processed_results.append({"instances": r, 'background': background_r})
        return processed_results
================================================
FILE: tllib/alignment/d_adapt/modeling/meta_arch/retinanet.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Callable, Tuple, Any, List, Sequence, Dict
import random
import numpy as np
import torch
from torch import Tensor
from detectron2.structures import BoxMode, Boxes, Instances, pairwise_iou, ImageList
from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple
from detectron2.modeling import detector_postprocess
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.utils.events import get_event_storage
from tllib.vision.models.object_detection.meta_arch import TLRetinaNet
from ..matcher import MaxOverlapMatcher
@META_ARCH_REGISTRY.register()
class DecoupledRetinaNet(TLRetinaNet):
"""
RetinaNet for Decoupled Adaptation (D-adapt).
Different from that in Supervised Learning, DecoupledRetinaNet
1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training
2. generate foreground and background proposals during inference
Args:
backbone: a backbone module, must follow detectron2's backbone interface
head (nn.Module): a module that predicts logits and regression deltas
for each level from a list of per-level features
head_in_features (Tuple[str]): Names of the input feature maps to be used in head
anchor_generator (nn.Module): a module that creates anchors from a
list of features. Usually an instance of :class:`AnchorGenerator`
box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to
instance boxes
anchor_matcher (Matcher): label the anchors by matching them with ground truth.
num_classes (int): number of classes. Used to label background proposals.
# Loss parameters:
focal_loss_alpha (float): focal_loss_alpha
focal_loss_gamma (float): focal_loss_gamma
smooth_l1_beta (float): smooth_l1_beta
box_reg_loss_type (str): Options are "smooth_l1", "giou"
# Inference parameters:
test_score_thresh (float): Inference cls score threshold, only anchors with
score > INFERENCE_TH are considered for inference (to improve speed)
test_topk_candidates (int): Select topk candidates before NMS
test_nms_thresh (float): Overlap threshold used for non-maximum suppression
(suppress boxes with IoU >= this threshold)
max_detections_per_image (int):
Maximum number of detections to return per image during inference
(100 is based on the limit established for the COCO dataset).
# Input parameters
pixel_mean (Tuple[float]):
Values to be used for image normalization (BGR order).
To train on images of different number of channels, set different mean & std.
Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
pixel_std (Tuple[float]):
When using pre-trained models in Detectron1 or any MSRA models,
std has been absorbed into its conv1 weights, so the std needs to be set 1.
Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
vis_period (int):
The period (in terms of steps) for minibatch visualization at train time.
Set to 0 to disable.
input_format (str): Whether the model needs RGB, YUV, HSV etc.
finetune (bool): whether finetune the detector or train from scratch. Default: True
Inputs:
- batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
Each item in the list contains the inputs for one image.
For now, each item in the list is a dict that contains:
* image: Tensor, image in (C, H, W) format.
* instances (optional): groundtruth :class:`Instances`
* "height", "width" (int): the output resolution of the model, used in inference.
See :meth:`postprocess` for details.
- labeled (bool, optional): whether has ground-truth label
Outputs:
- outputs: A list of dict where each dict is the output for one input image.
The dict contains a key "instances" whose value is a :class:`Instances`
and a key "features" whose value is the features of middle layers.
The :class:`Instances` object has the following keys:
"pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
- losses: A dict of different losses
"""
def __init__(self, *args, max_samples_per_level=25, **kwargs):
super(DecoupledRetinaNet, self).__init__(*args, **kwargs)
self.max_samples_per_level = max_samples_per_level
self.max_matcher = MaxOverlapMatcher()
def forward_training(self, images, features, predictions, gt_instances=None, feedbacks=None, labeled=True):
# Transpose the Hi*Wi*A dimension to the middle:
pred_logits, pred_anchor_deltas = self._transpose_dense_predictions(
predictions, [self.num_classes, 4]
)
anchors = self.anchor_generator(features)
if labeled:
gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances)
losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)
else:
proposal_labels, proposal_boxes = self.label_pseudo_anchors(anchors, feedbacks)
losses = self.losses(anchors, pred_logits, proposal_labels, pred_anchor_deltas, proposal_boxes)
losses.pop('loss_box_reg')
return losses
    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True):
        """
        Run the detector.

        In training mode, returns the loss dict produced by
        :meth:`forward_training` (optionally logging a visualization every
        ``self.vis_period`` steps). In inference mode, returns one dict per
        image with an "instances" entry (post-processed predictions) and a
        "background" entry (post-processed sampled background proposals).

        Args:
            batched_inputs: a list of dicts, one per image, as produced by
                :class:`DatasetMapper` (see the class docstring for the fields).
            labeled (bool, optional): whether ground-truth labels are available;
                forwarded to :meth:`forward_training`.
        """
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        # Keep only the feature maps consumed by the detection head.
        features = [features[f] for f in self.head_in_features]
        predictions = self.head(features)
        if self.training:
            # Ground-truth instances and adaptor feedbacks are both optional;
            # which one is used is decided by `labeled` inside forward_training.
            if "instances" in batched_inputs[0]:
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            else:
                gt_instances = None
            if "feedbacks" in batched_inputs[0]:
                feedbacks = [x["feedbacks"].to(self.device) for x in batched_inputs]
            else:
                feedbacks = None
            losses = self.forward_training(images, features, predictions, gt_instances, feedbacks, labeled)
            # Periodically visualize predictions for debugging/monitoring.
            if self.vis_period > 0:
                storage = get_event_storage()
                if storage.iter % self.vis_period == 0:
                    results = self.forward_inference(images, features, predictions)
                    self.visualize_training(batched_inputs, results, feedbacks)
            return losses
        else:
            # sample_background must be called before inference
            # since inference will change predictions
            background_results = self.sample_background(images, features, predictions)
            results = self.forward_inference(images, features, predictions)
            processed_results = []
            for results_per_image, background_results_per_image, input_per_image, image_size in zip(
                results, background_results, batched_inputs, images.image_sizes
            ):
                # Rescale boxes to the caller-requested output resolution,
                # falling back to the network input size.
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                background_r = detector_postprocess(background_results_per_image, height, width)
                processed_results.append({"instances": r, "background": background_r})
            return processed_results
@torch.no_grad()
def label_pseudo_anchors(self, anchors, instances):
"""
Args:
anchors (list[Boxes]): A list of #feature level Boxes.
The Boxes contains anchors of this image on the specific feature level.
instances (list[Instances]): a list of N `Instances`s. The i-th
`Instances` contains the ground-truth per-instance annotations
for the i-th input image.
Returns:
list[Tensor]:
List of #img tensors. i-th element is a vector of labels whose length is
the total number of anchors across all feature maps (sum(Hi * Wi * A)).
Label values are in {-1, 0, ..., K}, with -1 means ignore, and K means background.
list[Tensor]:
i-th element is a Rx4 tensor, where R is the total number of anchors across
feature maps. The values are the matched gt boxes for each anchor.
Values are undefined for those anchors not labeled as foreground.
"""
anchors = Boxes.cat(anchors) # Rx4
gt_labels = []
matched_gt_boxes = []
for gt_per_image in instances:
match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors)
matched_idxs, anchor_labels = self.max_matcher(match_quality_matrix)
del match_quality_matrix
if len(gt_per_image) > 0:
matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs]
gt_labels_i = gt_per_image.gt_classes[matched_idxs]
# Anchors with label -1 are ignored.
gt_labels_i[anchor_labels == -1] = -1
else:
matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes
gt_labels.append(gt_labels_i)
matched_gt_boxes.append(matched_gt_boxes_i)
return gt_labels, matched_gt_boxes
def sample_background(
self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]]
):
pred_logits, pred_anchor_deltas = self._transpose_dense_predictions(
predictions, [self.num_classes, 4]
)
anchors = self.anchor_generator(features)
results: List[Instances] = []
for img_idx, image_size in enumerate(images.image_sizes):
scores_per_image = [x[img_idx].sigmoid() for x in pred_logits]
deltas_per_image = [x[img_idx] for x in pred_anchor_deltas]
results_per_image = self.sample_background_single_image(
anchors, scores_per_image, deltas_per_image, image_size
)
results.append(results_per_image)
return results
def sample_background_single_image(
self,
anchors: List[Boxes],
box_cls: List[Tensor],
box_delta: List[Tensor],
image_size: Tuple[int, int],
):
boxes_all = []
scores_all = []
# Iterate over every feature level
for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors):
# (HxWxAxK,)
predicted_prob = box_cls_i.max(dim=1).values
# 1. Keep boxes with confidence score lower than threshold
keep_idxs = predicted_prob < self.test_score_thresh
anchor_idxs = nonzero_tuple(keep_idxs)[0]
# 2. Random sample boxes
anchor_idxs = anchor_idxs[
random.sample(range(len(anchor_idxs)), k=min(len(anchor_idxs), self.max_samples_per_level))]
predicted_prob = predicted_prob[anchor_idxs]
anchors_i = anchors_i[anchor_idxs]
boxes_all.append(anchors_i.tensor)
scores_all.append(predicted_prob)
boxes_all, scores_all = [
cat(x) for x in [boxes_all, scores_all]
]
result = Instances(image_size)
result.pred_boxes = Boxes(boxes_all)
result.scores = 1. - scores_all # the confidence score to be background
result.pred_classes = torch.tensor([self.num_classes for _ in range(len(scores_all))])
return result
def visualize_training(self, batched_inputs, results, feedbacks=None):
"""
A function used to visualize ground truth images and final network predictions.
It shows ground truth bounding boxes on the original image and up to 20
predicted object bounding boxes on the original image.
Args:
batched_inputs (list): a list that contains input to the model.
results (List[Instances]): a list of #images elements returned by forward_inference().
"""
from detectron2.utils.visualizer import Visualizer
assert len(batched_inputs) == len(
results
), "Cannot visualize inputs and results of different sizes"
storage = get_event_storage()
max_boxes = 20
image_index = 0 # only visualize a single image
img = batched_inputs[image_index]["image"]
img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
v_gt = Visualizer(img, None)
v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes)
anno_img = v_gt.get_image()
processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1])
predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy()
v_pred = Visualizer(img, None)
v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes])
prop_img = v_pred.get_image()
num_classes = self.num_classes
if feedbacks is not None:
v_feedback_gt = Visualizer(img, None)
instance = feedbacks[0].to(torch.device("cpu"))
v_feedback_gt = v_feedback_gt.overlay_instances(
boxes=instance.proposal_boxes[instance.gt_classes != num_classes])
feedback_gt_img = v_feedback_gt.get_image()
v_feedback_gf = Visualizer(img, None)
v_feedback_gf = v_feedback_gf.overlay_instances(
boxes=instance.proposal_boxes[instance.gt_classes == num_classes])
feedback_gf_img = v_feedback_gf.get_image()
vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img))
vis_img = vis_img.transpose(2, 0, 1)
vis_name = f"Top: GT; Middle: Pred; Bottom: Feedback GT, Feedback GF"
else:
vis_img = np.concatenate((anno_img, prop_img), axis=1)
vis_img = vis_img.transpose(2, 0, 1)
vis_name = "Left: GT bounding boxes; Right: Predicted proposals"
storage.put_image(vis_name, vis_img)
================================================
FILE: tllib/alignment/d_adapt/modeling/roi_heads/__init__.py
================================================
from .roi_heads import DecoupledRes5ROIHeads
================================================
FILE: tllib/alignment/d_adapt/modeling/roi_heads/fast_rcnn.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Dict
from detectron2.layers import cat
from detectron2.modeling.roi_heads.fast_rcnn import (
_log_classification_stats,
FastRCNNOutputLayers
)
from detectron2.structures import Instances
from tllib.modules.loss import LabelSmoothSoftmaxCEV1
import torch
def label_smoothing_cross_entropy(input, target, *, reduction="mean", **kwargs):
    """
    Same as `tllib.modules.loss.LabelSmoothSoftmaxCEV1`, but returns 0 (instead of nan)
    for empty inputs.
    """
    if reduction == "mean" and target.numel() == 0:
        # Multiply by input.sum() so the zero stays connected to the graph
        # and gradients still flow to `input`.
        return input.sum() * 0.0
    criterion = LabelSmoothSoftmaxCEV1(reduction=reduction, **kwargs)
    return criterion(input, target)
class DecoupledFastRCNNOutputLayers(FastRCNNOutputLayers):
    """
    Two linear layers for predicting Fast R-CNN outputs:

    1. proposal-to-detection box regression deltas
    2. classification scores

    Identical to :class:`FastRCNNOutputLayers` except that the classification
    loss is label-smoothing cross-entropy instead of plain cross-entropy.
    """
    def losses(self, predictions, proposals):
        """
        Args:
            predictions: return values of :meth:`forward()`.
            proposals (list[Instances]): proposals that match the features that were used
                to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``,
                ``gt_classes`` are expected.

        Returns:
            Dict[str, Tensor]: dict of losses
        """
        scores, proposal_deltas = predictions

        # parse classification outputs
        if len(proposals):
            gt_classes = cat([p.gt_classes for p in proposals], dim=0)
        else:
            gt_classes = torch.empty(0)
        _log_classification_stats(scores, gt_classes)

        # parse box regression outputs
        if not len(proposals):
            proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)
        else:
            proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0)  # Nx4
            assert not proposal_boxes.requires_grad, "Proposals should not require gradients!"
            # Proposals without "gt_boxes" must be all-negative; their own
            # proposal_boxes are used as harmless placeholders because those
            # values are never read inside self.box_reg_loss().
            gt_boxes = cat(
                [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals],
                dim=0,
            )

        losses = {
            "loss_cls": label_smoothing_cross_entropy(scores, gt_classes, reduction="mean"),
            "loss_box_reg": self.box_reg_loss(
                proposal_boxes, gt_boxes, proposal_deltas, gt_classes
            ),
        }
        return {name: loss * self.loss_weight.get(name, 1.0) for name, loss in losses.items()}
================================================
FILE: tllib/alignment/d_adapt/modeling/roi_heads/roi_heads.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import numpy as np
import random
from typing import List, Tuple, Dict
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage
from detectron2.layers import ShapeSpec, batched_nms
from detectron2.modeling.roi_heads import (
ROI_HEADS_REGISTRY,
Res5ROIHeads,
StandardROIHeads
)
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference
from detectron2.modeling.sampling import subsample_labels
from detectron2.layers import nonzero_tuple
from .fast_rcnn import DecoupledFastRCNNOutputLayers
@ROI_HEADS_REGISTRY.register()
class DecoupledRes5ROIHeads(Res5ROIHeads):
    """
    The ROIHeads in a typical "C4" R-CNN model, where
    the box and mask head share the cropping and
    the per-region feature computation by a Res5 block.

    It typically contains logic to

    1. when training on labeled source domain, match proposals with ground truth and sample them
    2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them
    3. crop the regions and extract per-region features using proposals
    4. make per-region predictions with different heads
    """
    def __init__(self, *args, **kwargs):
        super(DecoupledRes5ROIHeads, self).__init__(*args, **kwargs)

    @classmethod
    def from_config(cls, cfg, input_shape):
        # fmt: off
        ret = super().from_config(cfg, input_shape)
        ret["res5"], out_channels = cls._build_res5_block(cfg)
        # Swap in the label-smoothing box predictor; everything else is inherited.
        box_predictor = DecoupledFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1))
        ret["box_predictor"] = box_predictor
        return ret

    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):
        """
        Prepare some proposals to be used to train the ROI heads.
        When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            images (ImageList):
            features (dict[str,Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            proposals (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains object proposals for the i-th input image,
                with fields "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image. Specify `targets` during training only.
                It may have the following fields:

                - gt_boxes: the bounding box of each instance.
                - gt_classes: the label for each instance with a category ranging in [0, #class].
                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
                - gt_keypoints: NxKx3, the groud-truth keypoints for each instance.
            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the feedback of per-instance annotations
                for the i-th input image. Specify `feedbacks` during training only.
                It has the same fields as `targets`.
            labeled (bool, optional): whether has ground-truth label

        Returns:
            tuple[list[Instances], list[Instances], dict]:
                a tuple containing foreground proposals (`Instances`), background proposals (`Instances`) and a dict of different losses.
                Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
        """
        del images
        if self.training:
            assert targets
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                proposals = self.label_and_sample_feedbacks(feedbacks)
        del targets

        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))

        if self.training:
            del features
            losses = self.box_predictor.losses(predictions, proposals)
            if not labeled:
                # Regression targets from pseudo boxes are unreliable;
                # train classification only on the target domain.
                losses.pop("loss_box_reg")
            return [], [], losses
        else:
            # NOTE: the previous version first called self.box_predictor.inference()
            # and immediately discarded its result in favor of fast_rcnn_inference()
            # below; that redundant full inference pass (thresholding + NMS) is removed.
            boxes = self.box_predictor.predict_boxes(predictions, proposals)
            scores = self.box_predictor.predict_probs(predictions, proposals)
            image_shapes = [x.image_size for x in proposals]
            pred_instances, _ = fast_rcnn_inference(
                boxes,
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            # Also sample confident-background proposals for the adaptors.
            background_instances, _ = fast_rcnn_sample_background(
                [box.tensor for box in proposal_boxes],
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            background_instances = self.forward_with_given_boxes(features, background_instances)
            return pred_instances, background_instances, {}

    @torch.no_grad()
    def label_and_sample_feedbacks(
        self, feedbacks, batch_size_per_image=256
    ) -> List[Instances]:
        """
        Prepare some proposals to be used to train the ROI heads.
        It performs box matching between `proposals` and `feedbacks`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the feedback of per-instance annotations
                for the i-th input image. Specify `feedbacks` during training only.
                It has the same fields as `targets`.

        Returns:
            list[Instances]:
                length `N` list of `Instances`s containing the proposals
                sampled for training. Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
        """
        proposals_with_gt = []

        num_fg_samples = []
        num_bg_samples = []
        for feedbacks_per_image in feedbacks:
            gt_classes = feedbacks_per_image.gt_classes
            positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0]
            # keep roughly the same number of bg and fg boxes in each batch
            batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1))
            sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
                gt_classes, batch_size, self.positive_fraction, self.num_classes
            )
            sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
            gt_classes = gt_classes[sampled_idxs]

            # Set target attributes of the sampled proposals:
            proposals_per_image = feedbacks_per_image[sampled_idxs]
            proposals_per_image.gt_classes = gt_classes

            num_bg_samples.append((gt_classes == self.num_classes).sum().item())
            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
            proposals_with_gt.append(proposals_per_image)

        # Log the number of fg/bg samples that are selected for training ROI heads
        storage = get_event_storage()
        storage.put_scalar("roi_head_pseudo/num_fg_samples", np.mean(num_fg_samples))
        storage.put_scalar("roi_head_pseudo/num_bg_samples", np.mean(num_bg_samples))
        return proposals_with_gt
@ROI_HEADS_REGISTRY.register()
class DecoupledStandardROIHeads(StandardROIHeads):
    """
    The Standard ROIHeads used by most models, such as FPN and C5.
    It's "standard" in a sense that there is no ROI transform sharing
    or feature sharing between tasks.
    Each head independently processes the input features by each head's
    own pooler and head.

    It typically contains logic to

    1. when training on labeled source domain, match proposals with ground truth and sample them
    2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them
    3. crop the regions and extract per-region features using proposals
    4. make per-region predictions with different heads
    """
    def __init__(self, *args, **kwargs):
        super(DecoupledStandardROIHeads, self).__init__(*args, **kwargs)

    @classmethod
    def from_config(cls, cfg, input_shape):
        # fmt: off
        ret = super().from_config(cfg, input_shape)
        # Swap in the label-smoothing box predictor; all other sub-modules
        # come from the parent's from_config.
        box_predictor = DecoupledFastRCNNOutputLayers(cfg, ret['box_head'].output_shape)
        ret["box_predictor"] = box_predictor
        return ret

    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):
        """
        Prepare some proposals to be used to train the ROI heads.
        When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            images (ImageList):
            features (dict[str,Tensor]): input data as a mapping from feature
                map name to tensor. Axis 0 represents the number of images `N` in
                the input data; axes 1-3 are channels, height, and width, which may
                vary between feature maps (e.g., if a feature pyramid is used).
            proposals (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains object proposals for the i-th input image,
                with fields "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the ground-truth per-instance annotations
                for the i-th input image. Specify `targets` during training only.
                It may have the following fields:

                - gt_boxes: the bounding box of each instance.
                - gt_classes: the label for each instance with a category ranging in [0, #class].
                - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
                - gt_keypoints: NxKx3, the groud-truth keypoints for each instance.
            feedbacks (list[Instances], optional): length `N` list of `Instances`. The i-th
                `Instances` contains the feedback of per-instance annotations
                for the i-th input image. Specify `feedbacks` during training only.
                It has the same fields as `targets`.
            labeled (bool, optional): whether has ground-truth label

        Returns:
            tuple[list[Instances], list[Instances], dict]:
                a tuple containing foreground proposals (`Instances`), background proposals (`Instances`) and a dict of different losses.
                Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
        """
        del images
        if self.training:
            assert targets
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                proposals = self.label_and_sample_feedbacks(feedbacks)
        del targets
        if self.training:
            losses = self._forward_box(features, proposals)
            # Usually the original proposals used by the box head are used by the mask, keypoint
            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
            # predicted by the box head.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))
            if not labeled:
                # Regression targets derived from pseudo boxes are unreliable,
                # so only the classification loss is kept on the target domain.
                losses.pop('loss_box_reg')
            return [], [], losses
        else:
            pred_instances, predictions = self._forward_box(features, proposals)
            scores = self.box_predictor.predict_probs(predictions, proposals)
            image_shapes = [x.image_size for x in proposals]
            proposal_boxes = [x.proposal_boxes for x in proposals]
            # Sample confident-background proposals alongside the normal predictions.
            background_instances, _ = fast_rcnn_sample_background(
                [box.tensor for box in proposal_boxes],
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            background_instances = self.forward_with_given_boxes(features, background_instances)
            return pred_instances, background_instances, {}

    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """
        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,
        the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with
                their matching ground truth.
                Each has fields "proposal_boxes", and "objectness_logits",
                "gt_classes", "gt_boxes".

        Returns:
            In training, a dict of losses.
            In inference, a tuple of (pred_instances, raw predictions) — note this
            differs from the parent class, which returns only pred_instances;
            the raw predictions are needed by forward() to sample backgrounds.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        del box_features
        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            return losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances, predictions

    @torch.no_grad()
    def label_and_sample_feedbacks(
        self, feedbacks, batch_size_per_image=256
    ) -> List[Instances]:
        """
        Prepare some proposals to be used to train the ROI heads.
        It performs box matching between `proposals` and `targets`, and assigns
        training labels to the proposals.
        It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth
        boxes, with a fraction of positives that is no larger than
        ``self.positive_fraction``.

        Args:
            See :meth:`ROIHeads.forward`

        Returns:
            list[Instances]:
                length `N` list of `Instances`s containing the proposals
                sampled for training. Each `Instances` has the following fields:

                - proposal_boxes: the proposal boxes
                - gt_boxes: the ground-truth box that the proposal is assigned to
                  (this is only meaningful if the proposal has a label > 0; if label = 0
                  then the ground-truth box is random)

                Other fields such as "gt_classes", "gt_masks", that's included in `targets`.
        """
        proposals_with_gt = []

        num_fg_samples = []
        num_bg_samples = []
        for feedbacks_per_image in feedbacks:
            gt_classes = feedbacks_per_image.gt_classes
            # Foreground = labeled with a real class (not ignore, not background).
            positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0]
            # ensure each batch consists the same number bg and fg boxes
            batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1))
            sampled_fg_idxs, sampled_bg_idxs = subsample_labels(
                gt_classes, batch_size, self.positive_fraction, self.num_classes
            )
            sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0)
            gt_classes = gt_classes[sampled_idxs]
            # Set target attributes of the sampled proposals:
            proposals_per_image = feedbacks_per_image[sampled_idxs]
            proposals_per_image.gt_classes = gt_classes
            num_bg_samples.append((gt_classes == self.num_classes).sum().item())
            num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1])
            proposals_with_gt.append(proposals_per_image)
        # Log the number of fg/bg samples that are selected for training ROI heads
        storage = get_event_storage()
        storage.put_scalar("roi_head_pseudo/num_fg_samples", np.mean(num_fg_samples))
        storage.put_scalar("roi_head_pseudo/num_bg_samples", np.mean(num_bg_samples))
        return proposals_with_gt
def fast_rcnn_sample_background(
    boxes: List[torch.Tensor],
    scores: List[torch.Tensor],
    image_shapes: List[Tuple[int, int]],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Call `fast_rcnn_sample_background_single_image` for all images.

    Args:
        boxes (list[Tensor]): per-image predicted boxes; element i has shape
            (Ri, K * 4) for class-specific regression or (Ri, 4) for
            class-agnostic regression, where Ri is the number of predictions
            for image i. Compatible with :meth:`FastRCNNOutputLayers.predict_boxes`.
        scores (list[Tensor]): per-image class scores; element i has shape
            (Ri, K + 1). Compatible with :meth:`FastRCNNOutputLayers.predict_probs`.
        image_shapes (list[tuple]): per-image shape tuples, one per image in the batch.
        score_thresh (float): background proposals are those whose background
            score exceeds this threshold.
        nms_thresh (float): IoU threshold for box non-maximum suppression, in [0, 1].
        topk_per_image (int): number of background samples to keep per image;
            set < 0 to keep all.

    Returns:
        instances (list[Instances]): N elements, one per image, holding the
            sampled background proposals.
        kept_indices (list[Tensor]): N 1D tensors; element i gives the indices
            (in [0, Ri)) of the kept rows from the input for image i.
    """
    instances = []
    kept_indices = []
    for boxes_i, scores_i, shape_i in zip(boxes, scores, image_shapes):
        result_i, keep_i = fast_rcnn_sample_background_single_image(
            boxes_i, scores_i, shape_i, score_thresh, nms_thresh, topk_per_image
        )
        instances.append(result_i)
        kept_indices.append(keep_i)
    return instances, kept_indices
def fast_rcnn_sample_background_single_image(
    boxes,
    scores,
    image_shape: Tuple[int, int],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Sample background proposals for a single image.

    Args:
        Same as `fast_rcnn_sample_background`, but with boxes, scores, and image shapes
        per image.

    Returns:
        Same as `fast_rcnn_sample_background`, but for only one image.
    """
    # Drop rows that contain NaN/Inf in either the boxes or the scores.
    valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    if not valid_mask.all():
        boxes = boxes[valid_mask]
        scores = scores[valid_mask]

    # Number of score columns (K foreground classes + 1 background column).
    num_classes = scores.shape[1]
    # Only keep background proposals
    scores = scores[:, -1:]
    # Convert to Boxes to use the `clip` function ...
    boxes = Boxes(boxes.reshape(-1, 4))
    boxes.clip(image_shape)
    boxes = boxes.tensor.view(-1, 1, 4)  # R x C x 4

    # 1. Filter results based on detection scores. It can make NMS more efficient
    # by filtering out low-confidence detections.
    # Here "confidence" is the background score: keep proposals that are
    # confidently background.
    filter_mask = scores > score_thresh  # R
    # R' x 2. First column contains indices of the R predictions;
    # Second column contains indices of classes.
    # (Only the background column survived the slice above, so the second
    # column is always 0 here.)
    filter_inds = filter_mask.nonzero()
    boxes = boxes[filter_mask]
    scores = scores[filter_mask]

    # 2. Apply NMS only for background class
    keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
    if 0 <= topk_per_image < len(keep):
        # Randomly subsample (rather than take the top-scoring) kept boxes;
        # sorting keeps the subsample in the original NMS ordering.
        idx = list(range(len(keep)))
        idx = random.sample(idx, k=topk_per_image)
        idx = sorted(idx)
        keep = keep[idx]
    boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]

    result = Instances(image_shape)
    result.pred_boxes = Boxes(boxes)
    result.scores = scores
    # filter_inds[:, 1] is all zeros (see above), so every sample gets the
    # background class index num_classes - 1.
    result.pred_classes = filter_inds[:, 1] + num_classes - 1
    return result, filter_inds[:, 0]
================================================
FILE: tllib/alignment/d_adapt/proposal.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import copy
import numpy as np
import os
import json
from typing import Optional, Callable, List
import random
import pprint
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
from torchvision.transforms.functional import crop
from detectron2.structures import pairwise_iou
from detectron2.evaluation.evaluator import DatasetEvaluator
from detectron2.data.dataset_mapper import DatasetMapper
import detectron2.data.detection_utils as utils
import detectron2.data.transforms as T
class ProposalMapper(DatasetMapper):
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used by the model.

    Unlike the default :class:`DatasetMapper`, no augmentations are applied:
    the image is loaded as-is and annotations are converted against the
    original image shape.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1. Read the image from "file_name"
    2. Prepare data and annotations to Tensor and :class:`Instances`
    """
    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)
        origin_image_shape = image.shape[:2]  # h, w

        aug_input = T.AugInput(image)
        image = aug_input.image

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if "annotations" not in dataset_dict:
            return dataset_dict

        # USER: Modify this if you want to keep them for some reason.
        for anno in dataset_dict["annotations"]:
            if not self.use_instance_mask:
                anno.pop("segmentation", None)
            if not self.use_keypoint:
                anno.pop("keypoints", None)

        # USER: Implement additional transformations if you have other types of data
        annos = [obj for obj in dataset_dict.pop("annotations") if obj.get("iscrowd", 0) == 0]
        instances = utils.annotations_to_instances(
            annos, origin_image_shape, mask_format=self.instance_mask_format
        )

        # After transforms such as cropping are applied, the bounding box may no longer
        # tightly bound the object; recompute tight boxes from the masks if requested.
        if self.recompute_boxes:
            instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
        dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict
class ProposalGenerator(DatasetEvaluator):
    """
    The function :func:`inference_on_dataset` runs the model over
    all samples in the dataset and uses a ProposalGenerator to generate proposals for each input/output.
    This class will accumulate information of the inputs/outputs (by :meth:`process`),
    and generate proposal results in the end (by :meth:`evaluate`).

    Args:
        iou_threshold (tuple): (background upper bound, foreground lower bound) on the IoU
            with the best-matching ground-truth box. Predictions with IoU at or below the
            first value are labeled as background; predictions strictly between the two
            values are labeled -1 (ignored).
        num_classes (int): number of foreground classes. This value itself is used as the
            background class label.
    """
    def __init__(self, iou_threshold=(0.4, 0.5), num_classes=20, *args, **kwargs):
        super(ProposalGenerator, self).__init__(*args, **kwargs)
        # Per-image Proposal objects accumulated from outputs["instances"] / outputs["background"].
        self.fg_proposal_list = []
        self.bg_proposal_list = []
        self.iou_threshold = iou_threshold
        self.num_classes = num_classes
    def process_type(self, inputs, outputs, type='instances'):
        """
        Build a :class:`Proposal` from the first input/output pair, reading predictions
        from ``outputs[0][type]``. If the input carries ``gt_boxes``, each predicted box
        is additionally assigned the class, box and IoU of its best-matching ground-truth
        box (thresholded into foreground / background / ignore).
        """
        cpu_device = torch.device('cpu')
        # NOTE(review): only element 0 of the batch is processed — this appears to
        # assume inference runs with batch size 1; confirm against the caller.
        input_instance = inputs[0]['instances'].to(cpu_device)
        output_instance = outputs[0][type].to(cpu_device)
        filename = inputs[0]['file_name']
        pred_boxes = output_instance.pred_boxes
        pred_scores = output_instance.scores
        pred_classes = output_instance.pred_classes
        proposal = Proposal(
            image_id=inputs[0]['image_id'],
            filename=filename,
            pred_boxes=pred_boxes.tensor.numpy(),
            pred_classes=pred_classes.numpy(),
            pred_scores=pred_scores.numpy(),
        )
        if hasattr(input_instance, 'gt_boxes'):
            gt_boxes = input_instance.gt_boxes
            # assign a gt label for each pred_box
            if pred_boxes.tensor.shape[0] == 0:
                # No predictions at all: every ground-truth field is empty.
                proposal.gt_fg_classes = proposal.gt_classes = proposal.gt_ious = proposal.gt_boxes = np.array([])
            elif gt_boxes.tensor.shape[0] == 0:
                # No ground truth in this image: every prediction is background with IoU 0.
                proposal.gt_fg_classes = proposal.gt_classes = np.array([self.num_classes for _ in range(pred_boxes.tensor.shape[0])])
                proposal.gt_ious = np.array([0. for _ in range(pred_boxes.tensor.shape[0])])
                proposal.gt_boxes = np.array([[0, 0, 0, 0] for _ in range(pred_boxes.tensor.shape[0])])
            else:
                # For each prediction, the IoU and index of its best-matching ground-truth box.
                gt_ious, gt_classes_idx = pairwise_iou(pred_boxes, gt_boxes).max(dim=1)
                gt_classes = input_instance.gt_classes[gt_classes_idx]
                # IMPORTANT: copy the raw foreground classes BEFORE the in-place
                # thresholding below overwrites entries of gt_classes.
                proposal.gt_fg_classes = copy.deepcopy(gt_classes.numpy())
                gt_classes[gt_ious <= self.iou_threshold[0]] = self.num_classes # background classes
                gt_classes[(self.iou_threshold[0] < gt_ious) & (gt_ious <= self.iou_threshold[1])] = -1 # ignore
                proposal.gt_classes = gt_classes.numpy()
                proposal.gt_ious = gt_ious.numpy()
                proposal.gt_boxes = input_instance.gt_boxes[gt_classes_idx].tensor.numpy()
        return proposal
    def process(self, inputs, outputs):
        """Accumulate one foreground and one background Proposal for this batch."""
        self.fg_proposal_list.append(self.process_type(inputs, outputs, "instances"))
        self.bg_proposal_list.append(self.process_type(inputs, outputs, "background"))
    def evaluate(self):
        """Return the accumulated (foreground, background) proposal lists."""
        return self.fg_proposal_list, self.bg_proposal_list
class Proposal:
    """
    A data structure that stores the proposals for a single image.

    Args:
        image_id (str): unique image identifier
        filename (str): image filename
        pred_boxes (numpy.ndarray): predicted boxes
        pred_classes (numpy.ndarray): predicted classes
        pred_scores (numpy.ndarray): class confidence score
        gt_classes (numpy.ndarray, optional): ground-truth classes, including background classes
        gt_boxes (numpy.ndarray, optional): ground-truth boxes
        gt_ious (numpy.ndarray, optional): IoU between predicted boxes and ground-truth boxes
        gt_fg_classes (numpy.ndarray, optional): ground-truth foreground classes, not including background classes
    """
    def __init__(self, image_id, filename, pred_boxes, pred_classes, pred_scores,
                 gt_classes=None, gt_boxes=None, gt_ious=None, gt_fg_classes=None):
        self.image_id = image_id
        self.filename = filename
        self.pred_boxes = pred_boxes
        self.pred_classes = pred_classes
        self.pred_scores = pred_scores
        self.gt_classes = gt_classes
        self.gt_boxes = gt_boxes
        self.gt_ious = gt_ious
        self.gt_fg_classes = gt_fg_classes

    @staticmethod
    def _tolist(array):
        # gt_* fields are documented as optional and default to None (e.g. when the
        # target domain has no labels); the previous implementation crashed on None.
        return None if array is None else array.tolist()

    def to_dict(self):
        """Serialize to a JSON-compatible dict. Missing ground-truth fields stay None."""
        return {
            "__proposal__": True,
            "image_id": self.image_id,
            "filename": self.filename,
            "pred_boxes": self.pred_boxes.tolist(),
            "pred_classes": self.pred_classes.tolist(),
            "pred_scores": self.pred_scores.tolist(),
            "gt_classes": self._tolist(self.gt_classes),
            "gt_boxes": self._tolist(self.gt_boxes),
            "gt_ious": self._tolist(self.gt_ious),
            "gt_fg_classes": self._tolist(self.gt_fg_classes)
        }

    def __str__(self):
        pp = pprint.PrettyPrinter(indent=2)
        return pp.pformat(self.to_dict())

    def __len__(self):
        # Number of predicted boxes in this image.
        return len(self.pred_boxes)

    def __getitem__(self, item):
        """Return a new Proposal with every per-box field indexed by ``item``."""
        def _index(array):
            # Skip indexing for absent (None) ground-truth fields.
            return None if array is None else array[item]
        return Proposal(
            image_id=self.image_id,
            filename=self.filename,
            pred_boxes=self.pred_boxes[item],
            pred_classes=self.pred_classes[item],
            pred_scores=self.pred_scores[item],
            gt_classes=_index(self.gt_classes),
            gt_boxes=_index(self.gt_boxes),
            gt_ious=_index(self.gt_ious),
            gt_fg_classes=_index(self.gt_fg_classes)
        )
class ProposalEncoder(json.JSONEncoder):
    """JSON encoder that serializes :class:`Proposal` objects via :meth:`Proposal.to_dict`."""
    def default(self, o):
        if not isinstance(o, Proposal):
            # Defer to the base class, which raises TypeError for unknown types.
            return super(ProposalEncoder, self).default(o)
        return o.to_dict()
def asProposal(dict):
    """
    ``object_hook`` for :func:`json.load` that rebuilds :class:`Proposal` objects
    (dicts tagged with the ``"__proposal__"`` key); any other dict is returned as-is.

    NOTE: the parameter shadows the ``dict`` builtin; the name is kept to preserve
    the public signature.
    """
    if '__proposal__' not in dict:
        return dict

    def _toarray(value):
        # gt_* entries are optional and may be serialized as None when the image
        # carries no ground truth; keep them as None instead of np.array(None).
        return None if value is None else np.array(value)

    return Proposal(
        dict["image_id"],
        dict["filename"],
        np.array(dict["pred_boxes"]),
        np.array(dict["pred_classes"]),
        np.array(dict["pred_scores"]),
        _toarray(dict["gt_classes"]),
        _toarray(dict["gt_boxes"]),
        _toarray(dict["gt_ious"]),
        _toarray(dict["gt_fg_classes"])
    )
class PersistentProposalList(list):
    """
    A list of proposals that can be written to, and restored from, a JSON cache file.

    Args:
        filename (str, optional): path of the cache file
    """
    def __init__(self, filename=None):
        super(PersistentProposalList, self).__init__()
        self.filename = filename

    def load(self):
        """
        Load from cache.

        Return:
            whether succeed
        """
        if not os.path.exists(self.filename):
            return False
        print("Reading from cache: {}".format(self.filename))
        with open(self.filename, "r") as cache_file:
            # asProposal rebuilds Proposal objects from their tagged dicts.
            self.extend(json.load(cache_file, object_hook=asProposal))
        return True

    def flush(self):
        """
        Flush to cache.
        """
        # Create the parent directory on first write.
        os.makedirs(os.path.dirname(self.filename), exist_ok=True)
        with open(self.filename, "w") as cache_file:
            json.dump(self, cache_file, cls=ProposalEncoder)
        print("Write to cache: {}".format(self.filename))
def flatten(proposal_list, max_number=10000):
    """
    Flatten a list of per-image proposals into a list of single-proposal items.

    Args:
        proposal_list (list): a list of proposals grouped by images
        max_number (int): maximum number of kept proposals for each image

    Returns:
        list: one length-1 slice per kept proposal
    """
    flattened = []
    for proposals in proposal_list:
        kept = min(len(proposals), max_number)
        # Slicing (rather than indexing) keeps each item as a length-1 container.
        flattened.extend(proposals[i:i + 1] for i in range(kept))
    return flattened
class ProposalDataset(datasets.VisionDataset):
    """
    A dataset for proposals.

    Args:
        proposal_list (list): list of Proposal
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        crop_func: (ExpandCrop, optional): callable that adjusts the crop window
            ``(top, left, height, width)`` before the proposal is cut out of the image
    """
    def __init__(self, proposal_list: List[Proposal], transform: Optional[Callable] = None, crop_func=None):
        super(ProposalDataset, self).__init__("", transform=transform)
        self.proposal_list = list(filter(lambda p: len(p) > 0, proposal_list))  # remove images without proposals
        self.loader = default_loader
        self.crop_func = crop_func

    def __getitem__(self, index: int):
        """
        Return a randomly sampled proposal crop from the ``index``-th image, together
        with a dict of its annotations (as numpy arrays) and the full image size.
        """
        # get proposals for the index-th image
        proposals = self.proposal_list[index]
        img = self.loader(proposals.filename)
        # random sample a proposal
        proposal = proposals[random.randint(0, len(proposals) - 1)]
        image_width, image_height = img.width, img.height
        # crop the proposal from the whole image
        x1, y1, x2, y2 = proposal.pred_boxes
        top, left, height, width = int(y1), int(x1), int(y2 - y1), int(x2 - x1)
        if self.crop_func is not None:
            top, left, height, width = self.crop_func(img, top, left, height, width)
        img = crop(img, top, left, height, width)
        if self.transform is not None:
            img = self.transform(img)
        # FIX: np.float / np.long were removed in NumPy 1.24; use the builtin
        # aliases they pointed to (float -> float64, int -> platform int),
        # which preserves the original dtypes exactly.
        return img, {
            "image_id": proposal.image_id,
            "filename": proposal.filename,
            "pred_boxes": proposal.pred_boxes.astype(float),
            "pred_classes": proposal.pred_classes.astype(int),
            "pred_scores": proposal.pred_scores.astype(float),
            "gt_classes": proposal.gt_classes.astype(int),
            "gt_boxes": proposal.gt_boxes.astype(float),
            "gt_ious": proposal.gt_ious.astype(float),
            "gt_fg_classes": proposal.gt_fg_classes.astype(int),
            "width": image_width,
            "height": image_height
        }

    def __len__(self):
        return len(self.proposal_list)
class ExpandCrop:
    """
    Enlarges a crop window around its center by a fixed factor, so that the
    bounding box adaptor sees more surrounding context than the raw predicted box.
    """
    def __init__(self, expand=1.):
        self.expand = expand

    def __call__(self, img, top, left, height, width):
        # Scale the window while keeping its center fixed.
        center_x = left + width / 2.
        center_y = top + height / 2.
        new_height = round(height * self.expand)
        new_width = round(width * self.expand)
        new_top = round(center_y - new_height / 2.)
        new_left = round(center_x - new_width / 2.)
        return new_top, new_left, new_height, new_width
================================================
FILE: tllib/alignment/dan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn
from tllib.modules.classifier import Classifier as ClassifierBase
__all__ = ['MultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']
class MultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in
    `Learning Transferable Features with Deep Adaptation Networks (ICML 2015)`.

    Given activations :math:`\{z_i^s\}_{i=1}^{n_s}` from the source domain and
    :math:`\{z_i^t\}_{i=1}^{n_t}` from the target domain, MK-MMD is the squared
    distance between the kernel mean embeddings of the two distributions,

    .. math::
        D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k},

    where the kernel :math:`k = \sum_{u=1}^{m}\beta_{u} k_{u}` is a sum of single
    kernels. Via the kernel trick it is estimated from the Gram matrix of the
    concatenated batch.

    Args:
        kernels (tuple(torch.nn.Module)): kernel functions.
        linear (bool): whether use the linear version of DAN. Default: False

    Inputs:
        - z_s (tensor): activations from the source domain, :math:`z^s`
        - z_t (tensor): activations from the target domain, :math:`z^t`

    Shape:
        - Inputs: :math:`(minibatch, *)` where * means any dimension
        - Outputs: scalar

    .. note::
        :math:`z^{s}` and :math:`z^{t}` must have the same shape, and the values
        of multiple kernels add up.
    """
    def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False):
        super(MultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        # Cached coefficient matrix; rebuilt whenever the batch size changes.
        self.index_matrix = None
        self.linear = linear

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        batch_size = int(z_s.size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device)
        joint_features = torch.cat([z_s, z_t], dim=0)
        # Sum the Gram matrices of all kernels over the concatenated batch.
        gram = sum(kernel(joint_features) for kernel in self.kernels)
        # Add 2 / (n - 1) to make up for the zero diagonal, keeping the
        # non-linear estimate positive.
        return (gram * self.index_matrix).sum() + 2. / float(batch_size - 1)
def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None,
linear: Optional[bool] = True) -> torch.Tensor:
r"""
Update the `index_matrix` which convert `kernel_matrix` to loss.
If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`.
Else return a new tensor with shape (2 x batch_size, 2 x batch_size).
"""
if index_matrix is None or index_matrix.size(0) != batch_size * 2:
index_matrix = torch.zeros(2 * batch_size, 2 * batch_size)
if linear:
for i in range(batch_size):
s1, s2 = i, (i + 1) % batch_size
t1, t2 = s1 + batch_size, s2 + batch_size
index_matrix[s1, s2] = 1. / float(batch_size)
index_matrix[t1, t2] = 1. / float(batch_size)
index_matrix[s1, t2] = -1. / float(batch_size)
index_matrix[s2, t1] = -1. / float(batch_size)
else:
for i in range(batch_size):
for j in range(batch_size):
if i != j:
index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1))
index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1))
for i in range(batch_size):
for j in range(batch_size):
index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size)
index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size)
return index_matrix
class ImageClassifier(ClassifierBase):
    """Image classifier for DAN: backbone features pass through a one-layer
    bottleneck (Linear -> ReLU -> Dropout) before the classification head."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck_layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]
        super(ImageClassifier, self).__init__(
            backbone, num_classes, nn.Sequential(*bottleneck_layers), bottleneck_dim, **kwargs)
================================================
FILE: tllib/alignment/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.utils.metric import binary_accuracy, accuracy
__all__ = ['DomainAdversarialLoss']
class DomainAdversarialLoss(nn.Module):
    r"""
    The Domain Adversarial Loss proposed in
    `Domain-Adversarial Training of Neural Networks (ICML 2015)`.

    The domain discrepancy is measured by training a domain discriminator
    :math:`D` on feature representations :math:`f`:

    .. math::
        loss(\mathcal{D}_s, \mathcal{D}_t) = \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(f_i^s)]
        + \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(f_j^t)].

    Args:
        domain_discriminator (torch.nn.Module): predicts the domain of features;
            input shape (N, F), output shape (N, 1) when ``sigmoid=True``
        reduction (str, optional): ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
        grl (WarmStartGradientReverseLayer, optional): gradient reversal layer;
            a warm-start one is created when None. Default: None
        sigmoid (bool): if True the discriminator outputs a probability and binary
            cross entropy is used; otherwise logits and softmax cross entropy.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`
        - w_s (tensor, optional): per-instance weights for the source domain
        - w_t (tensor, optional): per-instance weights for the target domain

    Shape:
        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.
        - Outputs: scalar by default; :math:`(N, )` if :attr:`reduction` is ``'none'``.
    """
    def __init__(self, domain_discriminator: nn.Module, reduction: Optional[str] = 'mean',
                 grl: Optional = None, sigmoid=True):
        super(DomainAdversarialLoss, self).__init__()
        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True) if grl is None else grl
        self.domain_discriminator = domain_discriminator
        self.sigmoid = sigmoid
        self.reduction = reduction
        # Kept as a public attribute for callers that use it directly.
        self.bce = lambda input, target, weight: \
            F.binary_cross_entropy(input, target, weight=weight, reduction=reduction)
        # Updated on every forward pass for monitoring.
        self.domain_discriminator_accuracy = None

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor,
                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Reverse gradients so the feature extractor is trained adversarially.
        reversed_features = self.grl(torch.cat((f_s, f_t), dim=0))
        domain_pred = self.domain_discriminator(reversed_features)
        if self.sigmoid:
            d_s, d_t = domain_pred.chunk(2, dim=0)
            d_label_s = torch.ones((f_s.size(0), 1)).to(f_s.device)
            d_label_t = torch.zeros((f_t.size(0), 1)).to(f_t.device)
            self.domain_discriminator_accuracy = 0.5 * (
                binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))
            w_s = torch.ones_like(d_label_s) if w_s is None else w_s
            w_t = torch.ones_like(d_label_t) if w_t is None else w_t
            source_loss = F.binary_cross_entropy(d_s, d_label_s, weight=w_s.view_as(d_s), reduction=self.reduction)
            target_loss = F.binary_cross_entropy(d_t, d_label_t, weight=w_t.view_as(d_t), reduction=self.reduction)
            return 0.5 * (source_loss + target_loss)
        # Multi-output (softmax) discriminator: domain label 1 for source, 0 for target.
        d_label = torch.cat((
            torch.ones((f_s.size(0),)).to(f_s.device),
            torch.zeros((f_t.size(0),)).to(f_t.device),
        )).long()
        w_s = torch.ones((f_s.size(0),)).to(f_s.device) if w_s is None else w_s
        w_t = torch.ones((f_t.size(0),)).to(f_t.device) if w_t is None else w_t
        self.domain_discriminator_accuracy = accuracy(domain_pred, d_label)
        weighted_loss = F.cross_entropy(domain_pred, d_label, reduction='none') * torch.cat([w_s, w_t], dim=0)
        if self.reduction == "mean":
            return weighted_loss.mean()
        if self.reduction == "sum":
            return weighted_loss.sum()
        if self.reduction == "none":
            return weighted_loss
        raise NotImplementedError(self.reduction)
class ImageClassifier(ClassifierBase):
    """Image classifier for DANN: backbone features pass through a one-layer
    bottleneck (Linear -> BatchNorm -> ReLU) before the classification head."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck_layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        ]
        super(ImageClassifier, self).__init__(
            backbone, num_classes, nn.Sequential(*bottleneck_layers), bottleneck_dim, **kwargs)
================================================
FILE: tllib/alignment/jan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn
from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.modules.grl import GradientReverseLayer
from tllib.modules.kernels import GaussianKernel
from tllib.alignment.dan import _update_index_matrix
__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']
class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in
    `Deep Transfer Learning with Joint Adaptation Networks (ICML 2017)`.

    Given activations from several layers :math:`\mathcal{L}` on the source
    (:math:`\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}`) and target
    (:math:`\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}`) domains,
    :math:`\hat{D}_{\mathcal{L}}(P, Q)` is the squared distance between the
    empirical kernel mean embeddings, where the joint kernel is the product over
    layers :math:`\prod_{l\in\mathcal{L}} k^l(\cdot, \cdot)`.

    Args:
        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]`
            corresponds to kernel :math:`k^{\mathcal{L}[r]}`.
        linear (bool): whether use the linear version of JAN. Default: False
        thetas (list(Theta)): use adversarial version JAN if not None. Default: None

    Inputs:
        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`

    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)` where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape,
        and kernel values add up within a layer.
    """
    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True,
                 thetas: Sequence[nn.Module] = None):
        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        # Cached coefficient matrix; rebuilt whenever the batch size changes.
        self.index_matrix = None
        self.linear = linear
        # Identity transforms reduce adversarial JAN to plain JAN.
        self.thetas = thetas if thetas else [nn.Identity() for _ in kernels]

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)
        kernel_matrix = torch.ones_like(self.index_matrix)
        # Joint kernel: product over layers of (sum of kernels within the layer).
        for layer_s, layer_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            layer_features = theta(torch.cat([layer_s, layer_t], dim=0))
            kernel_matrix = kernel_matrix * sum(kernel(layer_features) for kernel in layer_kernels)
        # Add 2 / (n - 1) to make up for the zero diagonal, keeping the
        # non-linear estimate positive.
        return (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
class Theta(nn.Module):
    r"""
    Adversarial transform for JAN: the two gradient reversals make the
    parameters of ``layer1`` maximize the loss while the incoming features
    still minimize it.
    """
    def __init__(self, dim: int):
        super(Theta, self).__init__()
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        # Start from the identity map so training begins as plain JMMD.
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        reversed_features = self.grl1(features)
        transformed = self.layer1(reversed_features)
        return self.grl2(transformed)
class ImageClassifier(ClassifierBase):
    """Image classifier for JAN: backbone features pass through a one-layer
    bottleneck (Linear -> BatchNorm -> ReLU -> Dropout) before the head."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck_layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
        ]
        super(ImageClassifier, self).__init__(
            backbone, num_classes, nn.Sequential(*bottleneck_layers), bottleneck_dim, **kwargs)
================================================
FILE: tllib/alignment/mcd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch.nn as nn
import torch
def classifier_discrepancy(predictions1: torch.Tensor, predictions2: torch.Tensor) -> torch.Tensor:
    r"""The `Classifier Discrepancy` in
    `Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR 2018)`.

    The discrepancy between predictions :math:`p_1` and :math:`p_2` is the mean
    absolute difference over all classes:

    .. math::
        d(p_1, p_2) = \dfrac{1}{K} \sum_{k=1}^K | p_{1k} - p_{2k} |,

    where K is number of classes.

    Args:
        predictions1 (torch.Tensor): Classifier predictions :math:`p_1`. Expected to contain raw, normalized scores for each class
        predictions2 (torch.Tensor): Classifier predictions :math:`p_2`
    """
    return (predictions1 - predictions2).abs().mean()
def entropy(predictions: torch.Tensor) -> torch.Tensor:
    r"""Entropy of N predictions :math:`(p_1, p_2, ..., p_N)`:

    .. math::
        d(p_1, p_2, ..., p_N) = -\dfrac{1}{K} \sum_{k=1}^K \log \left( \dfrac{1}{N} \sum_{i=1}^N p_{ik} \right)

    where K is number of classes.

    .. note::
        This entropy is specific to MCD and differs from the usual
        :meth:`~tllib.modules.entropy.entropy` function.

    Args:
        predictions (torch.Tensor): Classifier predictions. Expected to contain raw, normalized scores for each class
    """
    mean_prediction = predictions.mean(dim=0)
    # 1e-6 guards against log(0).
    return -torch.log(mean_prediction + 1e-6).mean()
class ImageClassifierHead(nn.Module):
    r"""Classifier Head for MCD.

    Args:
        in_features (int): Dimension of input features
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        pool_layer (torch.nn.Module, optional): pooling applied before the head;
            defaults to global average pooling plus flatten.

    Shape:
        - Inputs: :math:`(minibatch, F)` where F = `in_features`.
        - Output: :math:`(minibatch, C)` where C = `num_classes`.
    """
    def __init__(self, in_features: int, num_classes: int, bottleneck_dim: Optional[int] = 1024, pool_layer=None):
        super(ImageClassifierHead, self).__init__()
        self.num_classes = num_classes
        # Default pooling collapses spatial dims: (N, C, H, W) -> (N, C).
        self.pool_layer = pool_layer if pool_layer is not None else nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
        )
        head_layers = [
            nn.Dropout(0.5),
            nn.Linear(in_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(bottleneck_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        pooled = self.pool_layer(inputs)
        return self.head(pooled)
================================================
FILE: tllib/alignment/mdd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Dict, Tuple, Callable
import torch.nn as nn
import torch.nn.functional as F
import torch
from tllib.modules.grl import WarmStartGradientReverseLayer
class MarginDisparityDiscrepancy(nn.Module):
    r"""The margin disparity discrepancy (MDD) proposed in
    `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019)`,
    a measure of distribution discrepancy in domain adaptation.

    With :math:`y^s, y^t` the logits of the main head and
    :math:`y_{adv}^s, y_{adv}^t` those of the adversarial head:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        -\gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} L_s (y^s, y_{adv}^s) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} L_t (y^t, y_{adv}^t),

    where :math:`\gamma` is a margin hyper-parameter and :math:`L_s`, :math:`L_t`
    are the disparity functions on the source and target domain.

    Args:
        source_disparity (callable): The disparity function defined on the source domain, :math:`L_s`.
        target_disparity (callable): The disparity function defined on the target domain, :math:`L_t`.
        margin (float): margin :math:`\gamma`. Default: 4
        reduction (str, optional): ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - y_s: output :math:`y^s` by the main head on the source domain
        - y_s_adv: output :math:`y^s` by the adversarial head on the source domain
        - y_t: output :math:`y^t` by the main head on the target domain
        - y_t_adv: output :math:`y_{adv}^t` by the adversarial head on the target domain
        - w_s (optional): instance weights for source domain
        - w_t (optional): instance weights for target domain
    """
    def __init__(self, source_disparity: Callable, target_disparity: Callable,
                 margin: Optional[float] = 4, reduction: Optional[str] = 'mean'):
        super(MarginDisparityDiscrepancy, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.source_disparity = source_disparity
        self.target_disparity = target_disparity

    def forward(self, y_s: torch.Tensor, y_s_adv: torch.Tensor, y_t: torch.Tensor, y_t_adv: torch.Tensor,
                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Source disparity is negated and scaled by the margin.
        source_term = -self.margin * self.source_disparity(y_s, y_s_adv)
        target_term = self.target_disparity(y_t, y_t_adv)
        # Missing weights default to all-ones (i.e. unweighted).
        source_term = source_term * (torch.ones_like(source_term) if w_s is None else w_s)
        target_term = target_term * (torch.ones_like(target_term) if w_t is None else w_t)
        combined = source_term + target_term
        if self.reduction == 'mean':
            return combined.mean()
        if self.reduction == 'sum':
            return combined.sum()
        return combined
class ClassificationMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):
    r"""
    The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) `_.
    It measures the distribution discrepancy in domain adaptation
    for classification.

    When margin is equal to 1, it's also called disparity discrepancy (DD).

    The :math:`y^s` and :math:`y^t` are logits output by the main classifier on the source and target domain respectively.
    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial classifier.
    They are expected to contain raw, unnormalized scores for each class.

    The definition can be described as:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        \gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} \log\left(\frac{\exp(y_{adv}^s[h_{y^s}])}{\sum_j \exp(y_{adv}^s[j])}\right) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} \log\left(1-\frac{\exp(y_{adv}^t[h_{y^t}])}{\sum_j \exp(y_{adv}^t[j])}\right),

    where :math:`\gamma` is a margin hyper-parameter and :math:`h_y` refers to the predicted label when the logits output is :math:`y`.
    You can see more details in `Bridging Theory and Algorithm for Domain Adaptation `_.

    Args:
        margin (float): margin :math:`\gamma`. Default: 4
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs:
        - y_s: logits output :math:`y^s` by the main classifier on the source domain
        - y_s_adv: logits output :math:`y_{adv}^s` by the adversarial classifier on the source domain
        - y_t: logits output :math:`y^t` by the main classifier on the target domain
        - y_t_adv: logits output :math:`y_{adv}^t` by the adversarial classifier on the target domain

    Shape:
        - Inputs: :math:`(minibatch, C)` where C = number of classes, or :math:`(minibatch, C, d_1, d_2, ..., d_K)`
          with :math:`K \geq 1` in the case of `K`-dimensional loss.
        - Output: scalar. If :attr:`reduction` is ``'none'``, then the same size as the target: :math:`(minibatch)`, or
          :math:`(minibatch, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss.

    Examples::

        >>> num_classes = 2
        >>> batch_size = 10
        >>> loss = ClassificationMarginDisparityDiscrepancy(margin=4.)
        >>> # logits output from source domain and target domain
        >>> y_s, y_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> # adversarial logits output from source domain and target domain
        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)
    """

    def __init__(self, margin: Optional[float] = 4, **kwargs):
        # Source term: per-instance cross-entropy between the adversarial
        # logits and the main classifier's hard prediction, i.e.
        # -log p_adv[h_y]. The base class multiplies it by -margin, which
        # yields the gamma * log p_adv[h_y] term of the formula above.
        def source_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            _, prediction = y.max(dim=1)
            return F.cross_entropy(y_adv, prediction, reduction='none')

        # Target term: log(1 - p_adv[h_y]) per instance, computed via
        # nll_loss on log(1 - softmax); shift_log guards against log(0)
        # when the adversarial softmax saturates at 1.
        def target_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            _, prediction = y.max(dim=1)
            return -F.nll_loss(shift_log(1. - F.softmax(y_adv, dim=1)), prediction, reduction='none')

        super(ClassificationMarginDisparityDiscrepancy, self).__init__(source_discrepancy, target_discrepancy, margin,
                                                                       **kwargs)
class RegressionMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):
    r"""
    The margin disparity discrepancy (MDD) proposed in `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) `_.
    It measures the distribution discrepancy in domain adaptation
    for regression.

    The :math:`y^s` and :math:`y^t` are logits output by the main regressor on the source and target domain respectively.
    The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the adversarial regressor.
    They are expected to contain ``normalized`` values for each factors.

    The definition can be described as:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        -\gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} L (y^s, y_{adv}^s) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} L (y^t, y_{adv}^t),

    where :math:`\gamma` is a margin hyper-parameter and :math:`L` refers to the disparity function defined on both domains.
    You can see more details in `Bridging Theory and Algorithm for Domain Adaptation `_.

    Args:
        loss_function (callable): The disparity function defined on both domains, :math:`L`. Default: F.l1_loss
        margin (float): margin :math:`\gamma`. Default: 1
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs:
        - y_s: logits output :math:`y^s` by the main regressor on the source domain
        - y_s_adv: logits output :math:`y_{adv}^s` by the adversarial regressor on the source domain
        - y_t: logits output :math:`y^t` by the main regressor on the target domain
        - y_t_adv: logits output :math:`y_{adv}^t` by the adversarial regressor on the target domain

    Shape:
        - Inputs: :math:`(minibatch, F)` where F = number of factors, or :math:`(minibatch, F, d_1, d_2, ..., d_K)`
          with :math:`K \geq 1` in the case of `K`-dimensional loss.
        - Output: scalar. If :attr:`reduction` is ``'none'``, the same size as the target: :math:`(minibatch)`, or
          :math:`(minibatch, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss.

    Examples::

        >>> num_outputs = 2
        >>> batch_size = 10
        >>> loss = RegressionMarginDisparityDiscrepancy(margin=4., loss_function=F.l1_loss)
        >>> # output from source domain and target domain
        >>> y_s, y_t = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)
        >>> # adversarial output from source domain and target domain
        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)
        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)
    """

    def __init__(self, margin: Optional[float] = 1, loss_function=F.l1_loss, **kwargs):
        # Both disparities use the same loss function L. The main regressor's
        # output y is detached so that gradients from the discrepancy only
        # flow into the adversarial regressor, never into the main head.
        def source_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            return loss_function(y_adv, y.detach(), reduction='none')

        def target_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            return loss_function(y_adv, y.detach(), reduction='none')

        super(RegressionMarginDisparityDiscrepancy, self).__init__(source_discrepancy, target_discrepancy, margin,
                                                                   **kwargs)
def shift_log(x: torch.Tensor, offset: Optional[float] = 1e-6) -> torch.Tensor:
    r"""
    Shift the input by ``offset`` and clamp it at 1 before taking the log:

    .. math::
        y = \log(\min(x + \text{offset}, 1))

    This avoids the exploding gradient of :math:`\log(x)` at :math:`x = 0`.

    Args:
        x (torch.Tensor): input tensor, expected to fall in [0., 1.]
        offset (float, optional): offset size. Default: 1e-6

    .. note::
        For inputs in [0., 1.] the output falls in [log(offset), 0].
    """
    return (x + offset).clamp(max=1.).log()
class GeneralModule(nn.Module):
    """Generic MDD network: ``backbone -> bottleneck -> (head, adv_head)``.

    The adversarial head receives features routed through a gradient reverse
    layer (GRL), so minimizing its loss maximizes the discrepancy w.r.t. the
    feature extractor. During training ``forward`` returns both heads'
    outputs; at evaluation time only the main head's output is returned.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: nn.Module,
                 head: nn.Module, adv_head: nn.Module, grl: Optional[WarmStartGradientReverseLayer] = None,
                 finetune: Optional[bool] = True):
        super(GeneralModule, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.bottleneck = bottleneck
        self.head = head
        self.adv_head = adv_head
        self.finetune = finetune
        # Fall back to the default warm-start GRL schedule when none is given.
        if grl is None:
            grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        self.grl_layer = grl

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """"""
        f = self.bottleneck(self.backbone(x))
        predictions = self.head(f)
        # The adversarial branch sees gradient-reversed features.
        predictions_adv = self.adv_head(self.grl_layer(f))
        return (predictions, predictions_adv) if self.training else predictions

    def step(self):
        r"""
        Gradually increase :math:`\lambda` in GRL layer.
        """
        self.grl_layer.step()

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """
        Return a parameters list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer.
        """
        # A finetuned backbone trains at 1/10 of the base learning rate.
        backbone_lr = 0.1 * base_lr if self.finetune else base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": base_lr},
            {"params": self.head.parameters(), "lr": base_lr},
            {"params": self.adv_head.parameters(), "lr": base_lr},
        ]
class ImageClassifier(GeneralModule):
    r"""Classifier for MDD.

    One backbone and one bottleneck are shared by two classifier heads: the
    main head produces the final predictions, while the adversarial head is
    used only inside the MarginDisparityDiscrepancy loss.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 1-d features from data
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): Feature dimension of the classifier head. Default: 1024
        grl (nn.Module): Gradient reverse layer. Will use default parameters if None. Default: None.
        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: True

    Inputs:
        - x (tensor): input data

    Outputs:
        - outputs: logits outputs by the main classifier
        - outputs_adv: logits outputs by the adversarial classifier

    Shape:
        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.
        - outputs, outputs_adv: :math:`(minibatch, C)`, where C means the number of classes.

    .. note::
        Remember to call function `step()` after function `forward()` **during training phase**! For instance,

        >>> # x is inputs, classifier is an ImageClassifier
        >>> outputs, outputs_adv = classifier(x)
        >>> classifier.step()
    """

    def __init__(self, backbone: nn.Module, num_classes: int,
                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024,
                 grl: Optional[WarmStartGradientReverseLayer] = None, finetune=True, pool_layer=None):
        if grl is None:
            grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        if pool_layer is None:
            pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )

        # Bottleneck: pooled backbone features -> bottleneck_dim, with a
        # dedicated (small-variance, positive-bias) initialization.
        projection = nn.Linear(backbone.out_features, bottleneck_dim)
        projection.weight.data.normal_(0, 0.005)
        projection.bias.data.fill_(0.1)
        bottleneck = nn.Sequential(
            pool_layer,
            projection,
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

        def make_head():
            # Two-layer MLP head; both Linear layers get N(0, 0.01) weights
            # and zero bias.
            mlp = nn.Sequential(
                nn.Linear(bottleneck_dim, width),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(width, num_classes)
            )
            for module in mlp:
                if isinstance(module, nn.Linear):
                    module.weight.data.normal_(0, 0.01)
                    module.bias.data.fill_(0.0)
            return mlp

        head = make_head()       # main classifier head (final predictions)
        adv_head = make_head()   # adversarial head (fed through the GRL)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck,
                                              head, adv_head, grl, finetune)
class ImageRegressor(GeneralModule):
    r"""Regressor for MDD.

    One backbone and one bottleneck are shared by two regressor heads: the
    main head produces the final predictions, while the adversarial head is
    used only inside the MarginDisparityDiscrepancy loss.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 1-d features from data
        num_factors (int): Number of factors
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): Feature dimension of the classifier head. Default: 1024
        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: True

    Inputs:
        - x (Tensor): input data

    Outputs: (outputs, outputs_adv)
        - outputs: outputs by the main regressor
        - outputs_adv: outputs by the adversarial regressor

    Shape:
        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.
        - outputs, outputs_adv: :math:`(minibatch, F)`, where F means the number of factors.

    .. note::
        Remember to call function `step()` after function `forward()` **during training phase**! For instance,

        >>> # x is inputs, regressor is an ImageRegressor
        >>> outputs, outputs_adv = regressor(x)
        >>> regressor.step()
    """

    def __init__(self, backbone: nn.Module, num_factors: int, bottleneck = None, head=None, adv_head=None,
                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024, finetune=True):
        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)

        def build_head():
            # Two conv blocks, global average pooling, then a sigmoid-bounded
            # linear projection to the factor space. Conv/Linear layers get
            # N(0, 0.01) weights and zero bias.
            layers = nn.Sequential(
                nn.Conv2d(bottleneck_dim, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
                nn.Linear(width, num_factors),
                nn.Sigmoid()
            )
            for module in layers:
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    nn.init.normal_(module.weight, 0, 0.01)
                    nn.init.constant_(module.bias, 0)
            return layers

        if bottleneck is None:
            bottleneck = nn.Sequential(
                nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(bottleneck_dim),
                nn.ReLU(),
            )
        # Only build (and initialize) default heads when the caller did not
        # supply custom ones.
        if head is None:
            head = build_head()
        if adv_head is None:
            adv_head = build_head()
        super(ImageRegressor, self).__init__(backbone, num_factors, bottleneck,
                                             head, adv_head, grl_layer, finetune)
        self.num_factors = num_factors
================================================
FILE: tllib/alignment/osbp.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.modules.grl import GradientReverseLayer
class UnknownClassBinaryCrossEntropy(nn.Module):
    r"""
    Binary cross entropy loss to make a boundary for unknown samples, proposed by
    `Open Set Domain Adaptation by Backpropagation (ECCV 2018) `_.

    Given a sample on target domain :math:`x_t` and its classifcation outputs :math:`y`, the binary cross entropy
    loss is defined as

    .. math::
        L_{\text{adv}}(x_t) = -t \text{log}(p(y=C+1|x_t)) - (1-t)\text{log}(1-p(y=C+1|x_t))

    where t is a hyper-parameter and C is the number of known classes.

    Args:
        t (float): Predefined hyper-parameter. Default: 0.5

    Inputs:
        - y (tensor): classification outputs (before softmax).

    Shape:
        - y: :math:`(minibatch, C+1)` where C is the number of known classes.
        - Outputs: scalar
    """

    def __init__(self, t: Optional[float]=0.5):
        super(UnknownClassBinaryCrossEntropy, self).__init__()
        self.t = t

    def forward(self, y):
        # y: N x (C+1); the last column is the "unknown" class.
        prob = F.softmax(y, dim=1)
        p_unknown = prob[:, -1].contiguous().view(-1, 1)
        p_known = 1. - p_unknown
        # Constant binary target t against which the unknown probability is pushed.
        target_unknown = torch.ones((y.size(0), 1)).to(y.device) * self.t
        target_known = 1. - target_unknown
        # 1e-6 guards against log(0) when the softmax saturates.
        loss_unknown = -torch.mean(target_unknown * torch.log(p_unknown + 1e-6))
        loss_known = -torch.mean(target_known * torch.log(p_known + 1e-6))
        return loss_unknown + loss_known
class ImageClassifier(ClassifierBase):
    """Image classifier for OSBP: a two-block fully-connected bottleneck plus
    an optional gradient-reversed forward pass for adversarial training."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        def fc_block(in_dim, out_dim):
            # Linear -> BatchNorm -> ReLU -> Dropout
            return [nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout()]

        bottleneck = nn.Sequential(
            *(fc_block(backbone.out_features, bottleneck_dim) + fc_block(bottleneck_dim, bottleneck_dim))
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        self.grl = GradientReverseLayer()

    def forward(self, x: torch.Tensor, grad_reverse: Optional[bool] = False):
        f = self.bottleneck(self.pool_layer(self.backbone(x)))
        if grad_reverse:
            # Reverse gradients so the head trains adversarially against the features.
            f = self.grl(f)
        logits = self.head(f)
        return (logits, f) if self.training else logits
================================================
FILE: tllib/alignment/regda.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tllib.modules.gl import WarmStartGradientLayer
from tllib.utils.metric.keypoint_detection import get_max_preds
class FastPseudoLabelGenerator2d(nn.Module):
    """Vectorized pseudo-label generator for 2-d keypoint heatmaps.

    For each keypoint channel, a Gaussian (std ``sigma``) is placed at the
    argmax of the predicted heatmap to form the ground-truth map; values
    below 1e-2 are zeroed. The ground-false map marks, for each keypoint,
    the (clamped) union of all *other* keypoints' Gaussians.
    """

    def __init__(self, sigma=2):
        super().__init__()
        self.sigma = sigma

    def forward(self, heatmap: torch.Tensor):
        heatmap = heatmap.detach()
        h, w = heatmap.shape[-2:]
        # Locate each keypoint's peak in the flattened spatial dims.
        flat_idx = heatmap.flatten(-2).argmax(dim=-1)                          # B, K
        peak_y = flat_idx.div(w, rounding_mode='floor')                        # B, K
        peak_x = flat_idx.remainder(w)                                         # B, K
        dy = torch.arange(h, device=heatmap.device) - peak_y.unsqueeze(-1)     # B, K, H
        dx = torch.arange(w, device=heatmap.device) - peak_x.unsqueeze(-1)     # B, K, W
        sq_dist = dy.square().unsqueeze(-1) + dx.square().unsqueeze(-2)        # B, K, H, W
        gaussian = torch.exp(sq_dist / (-2 * self.sigma * self.sigma))
        # Zero the negligible tail of each Gaussian (strictly-greater cut at 1e-2).
        ground_truth = torch.where(gaussian > 1e-2, gaussian, torch.zeros_like(gaussian))
        # A pixel is "false" for keypoint k when any other keypoint covers it.
        ground_false = (ground_truth.sum(dim=1, keepdim=True) - ground_truth).clamp(0., 1.)
        return ground_truth, ground_false
class PseudoLabelGenerator2d(nn.Module):
    """
    Generate ground truth heatmap and ground false heatmap from a prediction.

    Args:
        num_keypoints (int): Number of keypoints
        height (int): height of the heatmap. Default: 64
        width (int): width of the heatmap. Default: 64
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    Inputs:
        - y: predicted heatmap

    Outputs:
        - ground_truth: heatmap conforming to Gaussian distribution
        - ground_false: ground false heatmap

    Shape:
        - y: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - ground_truth: :math:`(minibatch, K, H, W)`
        - ground_false: :math:`(minibatch, K, H, W)`
    """

    def __init__(self, num_keypoints, height=64, width=64, sigma=2):
        super(PseudoLabelGenerator2d, self).__init__()
        self.height = height
        self.width = width
        self.sigma = sigma

        # Precompute one Gaussian heatmap per possible peak location:
        # heatmaps[mu_x][mu_y] is an (H, W) map whose value is 1 at (mu_x, mu_y).
        heatmaps = np.zeros((width, height, height, width), dtype=np.float32)

        tmp_size = sigma * 3
        for mu_x in range(width):
            for mu_y in range(height):
                # Bounding box of the Gaussian, possibly out of bounds.
                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]

                # Generate gaussian
                size = 2 * tmp_size + 1
                x = np.arange(0, size, 1, np.float32)
                y = x[:, np.newaxis]
                x0 = y0 = size // 2
                # The gaussian is not normalized, we want the center value to equal 1
                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

                # Usable gaussian range (clipped against the image borders).
                g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
                g_y = max(0, -ul[1]), min(br[1], height) - ul[1]
                # Image range
                img_x = max(0, ul[0]), min(br[0], width)
                img_y = max(0, ul[1]), min(br[1], height)

                heatmaps[mu_x][mu_y][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                    g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

        self.heatmaps = heatmaps
        # false_matrix[i][j] == 1 iff i != j: used to mark every *other*
        # keypoint's region as "false" for keypoint i.
        self.false_matrix = 1. - np.eye(num_keypoints, dtype=np.float32)

    def forward(self, y):
        """Look up precomputed Gaussians at the predicted peak locations."""
        B, K, H, W = y.shape
        y = y.detach()
        preds, max_vals = get_max_preds(y.cpu().numpy())  # B x K x (x, y)
        # BUGFIX: `np.int` (a deprecated alias of the builtin `int`) was
        # removed in NumPy 1.24 and raised AttributeError; use a fixed-width
        # integer type for indexing instead.
        preds = preds.reshape(-1, 2).astype(np.int64)
        ground_truth = self.heatmaps[preds[:, 0], preds[:, 1], :, :].copy().reshape(B, K, H, W).copy()

        # ground_false(k) = clip(sum of all other keypoints' ground truths, 0, 1)
        ground_false = ground_truth.reshape(B, K, -1).transpose((0, 2, 1))
        ground_false = ground_false.dot(self.false_matrix).clip(max=1., min=0.).transpose((0, 2, 1)).reshape(B, K, H, W).copy()
        return torch.from_numpy(ground_truth).to(y.device), torch.from_numpy(ground_false).to(y.device)
class RegressionDisparity(nn.Module):
    """
    Regression Disparity proposed by `Regressive Domain Adaptation for Unsupervised Keypoint Detection (CVPR 2021) `_.

    Args:
        pseudo_label_generator (PseudoLabelGenerator2d): generate ground truth heatmap and ground false heatmap
            from a prediction.
        criterion (torch.nn.Module): the loss function to calculate distance between two predictions.

    Inputs:
        - y: output by the main head
        - y_adv: output by the adversarial head
        - weight (optional): instance weights
        - mode (str): whether minimize the disparity or maximize the disparity. Choices includes ``min``, ``max``.
          Default: ``min``.

    Shape:
        - y: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - y_adv: :math:`(minibatch, K, H, W)`
        - weight: :math:`(minibatch, K)`.
        - Output: depends on the ``criterion``.

    Examples::

        >>> num_keypoints = 5
        >>> batch_size = 10
        >>> H = W = 64
        >>> pseudo_label_generator = PseudoLabelGenerator2d(num_keypoints)
        >>> from tllib.vision.models.keypoint_detection.loss import JointsKLLoss
        >>> loss = RegressionDisparity(pseudo_label_generator, JointsKLLoss())
        >>> # output from source domain and target domain
        >>> y_s, y_t = torch.randn(batch_size, num_keypoints, H, W), torch.randn(batch_size, num_keypoints, H, W)
        >>> # adversarial output from source domain and target domain
        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_keypoints, H, W), torch.randn(batch_size, num_keypoints, H, W)
        >>> # minimize regression disparity on source domain
        >>> output = loss(y_s, y_s_adv, mode='min')
        >>> # maximize regression disparity on target domain
        >>> output = loss(y_t, y_t_adv, mode='max')
    """

    def __init__(self, pseudo_label_generator: PseudoLabelGenerator2d, criterion: nn.Module):
        super(RegressionDisparity, self).__init__()
        self.criterion = criterion
        self.pseudo_label_generator = pseudo_label_generator

    def forward(self, y, y_adv, weight=None, mode='min'):
        assert mode in ['min', 'max']
        # Pseudo labels come from the (detached) main-head prediction.
        ground_truth, ground_false = self.pseudo_label_generator(y.detach())
        # Keep the generated maps around for visualization/debugging by callers.
        self.ground_truth = ground_truth
        self.ground_false = ground_false
        # 'min' pulls the adversarial head towards the ground truth;
        # 'max' pushes it towards the ground-false regions.
        target = ground_truth if mode == 'min' else ground_false
        return self.criterion(y_adv, target, weight)
class PoseResNet2d(nn.Module):
    """
    Pose ResNet for RegDA has one backbone, one upsampling, while two regression heads.

    Args:
        backbone (torch.nn.Module): Backbone to extract 2-d features from data
        upsampling (torch.nn.Module): Layer to upsample image feature to heatmap size
        feature_dim (int): The dimension of the features from upsampling layer.
        num_keypoints (int): Number of keypoints
        gl (WarmStartGradientLayer):
        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: True
        num_head_layers (int): Number of head layers. Default: 2

    Inputs:
        - x (tensor): input data

    Outputs:
        - outputs: logits outputs by the main regressor
        - outputs_adv: logits outputs by the adversarial regressor

    Shape:
        - x: :math:`(minibatch, *)`, same shape as the input of the `backbone`.
        - outputs, outputs_adv: :math:`(minibatch, K, H, W)`, where K means the number of keypoints.

    .. note::
        Remember to call function `step()` after function `forward()` **during training phase**! For instance,

        >>> # x is inputs, model is an PoseResNet
        >>> outputs, outputs_adv = model(x)
        >>> model.step()
    """

    def __init__(self, backbone, upsampling, feature_dim, num_keypoints,
                 gl: Optional[WarmStartGradientLayer] = None, finetune: Optional[bool] = True, num_head_layers=2):
        super(PoseResNet2d, self).__init__()
        self.backbone = backbone
        self.upsampling = upsampling
        self.head = self._make_head(num_head_layers, feature_dim, num_keypoints)
        self.head_adv = self._make_head(num_head_layers, feature_dim, num_keypoints)
        self.finetune = finetune
        # Default warm-start gradient schedule unless a custom layer is given.
        if gl is None:
            gl = WarmStartGradientLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        self.gl_layer = gl

    @staticmethod
    def _make_head(num_layers, channel_dim, num_keypoints):
        """Build ``num_layers - 1`` conv blocks plus a 1x1 projection to keypoints."""
        blocks = []
        for _ in range(num_layers - 1):
            blocks += [
                nn.Conv2d(channel_dim, channel_dim, 3, 1, 1),
                nn.BatchNorm2d(channel_dim),
                nn.ReLU(),
            ]
        blocks.append(
            nn.Conv2d(
                in_channels=channel_dim,
                out_channels=num_keypoints,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )
        head = nn.Sequential(*blocks)
        # Small-variance init for every conv in the head.
        for module in head.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.001)
                nn.init.constant_(module.bias, 0)
        return head

    def forward(self, x):
        features = self.upsampling(self.backbone(x))
        y = self.head(features)
        # The adversarial head sees the same features, routed through the
        # warm-start gradient layer.
        y_adv = self.head_adv(self.gl_layer(features))
        return (y, y_adv) if self.training else y

    def get_parameters(self, lr=1.):
        """Parameter groups with per-module relative learning rates."""
        backbone_lr = 0.1 * lr if self.finetune else lr
        return [
            {'params': self.backbone.parameters(), 'lr': backbone_lr},
            {'params': self.upsampling.parameters(), 'lr': lr},
            {'params': self.head.parameters(), 'lr': lr},
            {'params': self.head_adv.parameters(), 'lr': lr},
        ]

    def step(self):
        r"""Call step() each iteration during training.
        Will increase :math:`\lambda` in GL layer.
        """
        self.gl_layer.step()
================================================
FILE: tllib/alignment/rsd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
import torch
class RepresentationSubspaceDistance(nn.Module):
"""
`Representation Subspace Distance (ICML 2021) `_
Args:
trade_off (float): The trade-off value between Representation Subspace Distance
and Base Mismatch Penalization. Default: 0.1
Inputs:
- f_s (tensor): feature representations on source domain, :math:`f^s`
- f_t (tensor): feature representations on target domain, :math:`f^t`
"""
def __init__(self, trade_off=0.1):
super(RepresentationSubspaceDistance, self).__init__()
self.trade_off = trade_off
def forward(self, f_s, f_t):
U_s, _, _ = torch.svd(f_s.t())
U_t, _, _ = torch.svd(f_t.t())
P_s, cosine, P_t = torch.svd(torch.mm(U_s.t(), U_t))
sine = torch.sqrt(1 - torch.pow(cosine, 2))
rsd = torch.norm(sine, 1) # Representation Subspace Distance
bmp = torch.norm(torch.abs(P_s) - torch.abs(P_t), 2) # Base Mismatch Penalization
return rsd + self.trade_off * bmp
================================================
FILE: tllib/modules/__init__.py
================================================
from .classifier import *
from .regressor import *
from .grl import *
from .domain_discriminator import *
from .kernels import *
from .entropy import *
__all__ = ['classifier', 'regressor', 'grl', 'kernels', 'domain_discriminator', 'entropy']
================================================
FILE: tllib/modules/classifier.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Tuple, Optional, List, Dict
import torch.nn as nn
import torch
__all__ = ['Classifier']
class Classifier(nn.Module):
    """A generic Classifier class for domain adaptation.

    The network is ``backbone -> pool_layer -> bottleneck -> head``. During
    training ``forward`` returns ``(predictions, features)``; at evaluation
    time only ``predictions`` is returned.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes
        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1
        head (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True

    .. note::
        Different domain adaptation algorithms use different classifiers for
        better accuracy; this is a suggested default, not a core part of any
        algorithm — you may supply your own.

    .. note::
        By default the feature extractor (backbone) trains at 1/10 of the
        head's learning rate. Over-ride :meth:`~Classifier.get_parameters`
        for other optimization strategies.

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        - predictions: classifier's predictions
        - features: features after `bottleneck` layer and before `head` layer

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - predictions: (minibatch, `num_classes`)
        - features: (minibatch, `features_dim`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True, pool_layer=None):
        super(Classifier, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # Default pooling collapses spatial dims to a flat feature vector.
        self.pool_layer = pool_layer if pool_layer is not None else nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten()
        )
        if bottleneck is not None:
            self.bottleneck = bottleneck
            assert bottleneck_dim > 0
            self._features_dim = bottleneck_dim
        else:
            # No bottleneck: features come straight from the backbone.
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        self.head = head if head is not None else nn.Linear(self._features_dim, num_classes)
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """"""
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        predictions = self.head(features)
        return (predictions, features) if self.training else predictions

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
class ImageClassifier(Classifier):
    """Alias of :class:`Classifier`, kept so image-classification code can
    import a consistently named class. Adds no behavior."""
    pass
================================================
FILE: tllib/modules/domain_discriminator.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import List, Dict
import torch.nn as nn
__all__ = ['DomainDiscriminator']
class DomainDiscriminator(nn.Sequential):
    r"""Domain discriminator model from
    `Domain-Adversarial Training of Neural Networks (ICML 2015) `_

    Distinguish whether the input features come from the source domain or the target domain.
    The source domain label is 1 and the target domain label is 0.

    Args:
        in_feature (int): dimension of the input feature
        hidden_size (int): dimension of the hidden features
        batch_norm (bool): whether use :class:`~torch.nn.BatchNorm1d`.
            Use :class:`~torch.nn.Dropout` if ``batch_norm`` is False. Default: True.

    Shape:
        - Inputs: (minibatch, `in_feature`)
        - Outputs: :math:`(minibatch, 1)`
    """

    def __init__(self, in_feature: int, hidden_size: int, batch_norm=True, sigmoid=True):
        if sigmoid:
            # Single probability output in (0, 1).
            final_layer = nn.Sequential(
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()
            )
        else:
            # Two-way logits (e.g. for use with cross-entropy).
            final_layer = nn.Linear(hidden_size, 2)
        if batch_norm:
            hidden = [
                nn.Linear(in_feature, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
            ]
        else:
            hidden = [
                nn.Linear(in_feature, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
            ]
        super(DomainDiscriminator, self).__init__(*hidden, final_layer)

    def get_parameters(self) -> List[Dict]:
        return [{"params": self.parameters(), "lr": 1.}]
================================================
FILE: tllib/modules/entropy.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
    r"""Entropy of prediction.
    The definition is:

    .. math::
        entropy(p) = - \sum_{c=1}^C p_c \log p_c

    where C is number of classes.

    Args:
        predictions (tensor): Classifier predictions. Expected to contain normalized
            probabilities for each class (a small epsilon guards against log(0)).
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output. Default: ``'none'``

    Shape:
        - predictions: :math:`(minibatch, C)` where C means the number of classes.
        - Output: :math:`(minibatch, )` by default. If :attr:`reduction` is ``'mean'``, then scalar.
    """
    eps = 1e-5  # numerical guard: avoids log(0) for zero probabilities
    per_sample = -(predictions * torch.log(predictions + eps)).sum(dim=1)
    return per_sample.mean() if reduction == 'mean' else per_sample
================================================
FILE: tllib/modules/gl.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Any, Tuple
import numpy as np
import torch.nn as nn
from torch.autograd import Function
import torch
class GradientFunction(Function):
    """Autograd function: identity in the forward pass, scales the incoming
    gradient by ``coeff`` in the backward pass."""

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        # Multiply by 1.0 so a new tensor is returned and the op is recorded by autograd.
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # `None` is the gradient w.r.t. `coeff`, which is not a tensor input.
        return grad_output * ctx.coeff, None


class WarmStartGradientLayer(nn.Module):
    r"""Warm Start Gradient Layer :math:`\mathcal{R}(x)` with warm start

    The forward and backward behaviours are:

    .. math::
        \mathcal{R}(x) = x,

        \dfrac{ d\mathcal{R}} {dx} = \lambda I.

    :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:

    .. math::
        \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo

    where :math:`i` is the iteration step.

    Args:
        alpha (float, optional): :math:`α`. Default: 1.0
        lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0
        hi (float, optional): Final value of :math:`\lambda`. Default: 1.0
        max_iters (int, optional): :math:`N`. Default: 1000
        auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.
            Otherwise use function `step` to increase :math:`i`. Default: False
    """

    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):
        super(WarmStartGradientLayer, self).__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.iter_num = 0
        self.max_iters = max_iters
        self.auto_step = auto_step

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """"""
        # BUG FIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24,
        # raising AttributeError on current NumPy; the builtin `float` is equivalent.
        coeff = float(
            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
            - (self.hi - self.lo) + self.lo
        )
        if self.auto_step:
            self.step()
        return GradientFunction.apply(input, coeff)

    def step(self):
        """Increase iteration number :math:`i` by 1"""
        self.iter_num += 1
================================================
FILE: tllib/modules/grl.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Any, Tuple
import numpy as np
import torch.nn as nn
from torch.autograd import Function
import torch
class GradientReverseFunction(Function):
    """Autograd function: identity in the forward pass, negates and scales the
    incoming gradient by ``coeff`` in the backward pass."""

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor:
        ctx.coeff = coeff
        # Multiply by 1.0 so a new tensor is returned and the op is recorded by autograd.
        output = input * 1.0
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        # `None` is the gradient w.r.t. `coeff`, which is not a tensor input.
        return grad_output.neg() * ctx.coeff, None


class GradientReverseLayer(nn.Module):
    """Thin module wrapper around :class:`GradientReverseFunction`."""

    def __init__(self):
        super(GradientReverseLayer, self).__init__()

    def forward(self, *input):
        return GradientReverseFunction.apply(*input)


class WarmStartGradientReverseLayer(nn.Module):
    r"""Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start

    The forward and backward behaviours are:

    .. math::
        \mathcal{R}(x) = x,

        \dfrac{ d\mathcal{R}} {dx} = - \lambda I.

    :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule:

    .. math::
        \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo

    where :math:`i` is the iteration step.

    Args:
        alpha (float, optional): :math:`α`. Default: 1.0
        lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0
        hi (float, optional): Final value of :math:`\lambda`. Default: 1.0
        max_iters (int, optional): :math:`N`. Default: 1000
        auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called.
            Otherwise use function `step` to increase :math:`i`. Default: False
    """

    def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1.,
                 max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False):
        super(WarmStartGradientReverseLayer, self).__init__()
        self.alpha = alpha
        self.lo = lo
        self.hi = hi
        self.iter_num = 0
        self.max_iters = max_iters
        self.auto_step = auto_step

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """"""
        # BUG FIX: `np.float` was deprecated in NumPy 1.20 and removed in 1.24,
        # raising AttributeError on current NumPy; the builtin `float` is equivalent.
        coeff = float(
            2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters))
            - (self.hi - self.lo) + self.lo
        )
        if self.auto_step:
            self.step()
        return GradientReverseFunction.apply(input, coeff)

    def step(self):
        """Increase iteration number :math:`i` by 1"""
        self.iter_num += 1
================================================
FILE: tllib/modules/kernels.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
__all__ = ['GaussianKernel']
class GaussianKernel(nn.Module):
    r"""Gaussian Kernel Matrix

    Gaussian Kernel k is defined by

    .. math::
        k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right)

    where :math:`x_1, x_2 \in R^d` are 1-d tensors.

    Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),`

    .. math::
        K(X)_{i,j} = k(x_i, x_j)

    By default, during training this layer keeps a running estimate of the mean of
    L2 distances, which is then used to set hyperparameter :math:`\sigma`:
    :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`.
    If :attr:`track_running_stats` is ``False``, a fixed :math:`\sigma` is used instead.

    Args:
        sigma (float, optional): bandwidth :math:`\sigma`. Default: None
        track_running_stats (bool, optional): If ``True``, re-estimate :math:`\sigma^2`
            from every batch; otherwise always use the fixed :math:`\sigma^2`. Default: ``True``
        alpha (float, optional): :math:`\alpha` which decides the magnitude of
            :math:`\sigma^2` when track_running_stats is set to ``True``. Default: 1.

    Inputs:
        - X (tensor): input group :math:`X`

    Shape:
        - Inputs: :math:`(minibatch, F)` where F means the dimension of input features.
        - Outputs: :math:`(minibatch, minibatch)`
    """

    def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True,
                 alpha: Optional[float] = 1.):
        super(GaussianKernel, self).__init__()
        # A fixed sigma is mandatory when running statistics are not tracked.
        assert track_running_stats or sigma is not None
        self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None
        self.track_running_stats = track_running_stats
        self.alpha = alpha

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # Pairwise squared L2 distances via broadcasting: (1,N,F) - (N,1,F) -> (N,N).
        pairwise_sq_dist = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2)
        if self.track_running_stats:
            # Re-estimate the bandwidth from the batch; detached so no gradient
            # flows through the bandwidth itself.
            self.sigma_square = self.alpha * torch.mean(pairwise_sq_dist.detach())
        return torch.exp(-pairwise_sq_dist / (2 * self.sigma_square))
================================================
FILE: tllib/modules/loss.py
================================================
import torch.nn as nn
import torch
import torch.nn.functional as F
# version 1: use torch.autograd
class LabelSmoothSoftmaxCEV1(nn.Module):
    '''
    Label-smoothed softmax cross entropy.
    Adapted from https://github.com/CoinCheung/pytorch-loss
    '''

    def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-1):
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth
        self.reduction = reduction
        self.lb_ignore = ignore_index
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, target):
        '''
        Same usage method as nn.CrossEntropyLoss:
        >>> criteria = LabelSmoothSoftmaxCEV1()
        >>> logits = torch.randn(8, 19, 384, 384)  # nchw, float/half
        >>> lbs = torch.randint(0, 19, (8, 384, 384))  # nhw, int64_t
        >>> loss = criteria(logits, lbs)
        '''
        logits = input.float()  # compute in fp32 to avoid NaN under fp16
        # Build the smoothed one-hot targets without tracking gradients.
        with torch.no_grad():
            num_classes = logits.size(1)
            labels = target.clone().detach()
            ignore_mask = labels.eq(self.lb_ignore)
            n_valid = ignore_mask.eq(0).sum()
            labels[ignore_mask] = 0  # temporary class index so scatter_ stays in range
            on_value = 1. - self.lb_smooth
            off_value = self.lb_smooth / num_classes
            smooth_targets = torch.empty_like(logits).fill_(
                off_value).scatter_(1, labels.unsqueeze(1), on_value).detach()
        per_position = -torch.sum(self.log_softmax(logits) * smooth_targets, dim=1)
        per_position[ignore_mask] = 0  # ignored positions contribute nothing
        if self.reduction == 'mean':
            per_position = per_position.sum() / n_valid
        if self.reduction == 'sum':
            per_position = per_position.sum()
        return per_position
class KnowledgeDistillationLoss(nn.Module):
    """Knowledge Distillation Loss.

    KL divergence between the temperature-softened distributions of student and
    teacher logits.

    Args:
        T (double): Temperature. Default: 1.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'`` | ``'batchmean'``. Default: ``'batchmean'``

    Inputs:
        - y_student (tensor): logits output of the student
        - y_teacher (tensor): logits output of the teacher

    Shape:
        - y_student: (minibatch, `num_classes`)
        - y_teacher: (minibatch, `num_classes`)
    """

    def __init__(self, T=1., reduction='batchmean'):
        super(KnowledgeDistillationLoss, self).__init__()
        self.T = T
        self.kl = nn.KLDivLoss(reduction=reduction)

    def forward(self, y_student, y_teacher):
        """"""
        # KLDivLoss expects log-probabilities for the input, probabilities for the target.
        student_log_probs = F.log_softmax(y_student / self.T, dim=-1)
        teacher_probs = F.softmax(y_teacher / self.T, dim=-1)
        return self.kl(student_log_probs, teacher_probs)
================================================
FILE: tllib/modules/regressor.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Tuple, Optional, List, Dict
import torch.nn as nn
import torch
__all__ = ['Regressor']
class Regressor(nn.Module):
    """A generic Regressor class for domain adaptation.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_factors (int): Number of factors
        bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1
        head (torch.nn.Module, optional): Any regressor head. Uses ``nn.Linear`` followed by
            ``nn.Sigmoid`` by default (so default predictions are in [0, 1])
        finetune (bool): Whether finetune the regressor or train from scratch. Default: True

    .. note::
        The learning rate of this regressor is set 10 times to that of the feature extractor for better accuracy
        by default. If you have other optimization strategies, please over-ride :meth:`~Regressor.get_parameters`.

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        - predictions: regressor's predictions
        - features: features after `bottleneck` layer and before `head` layer
          (only returned in training mode)

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - predictions: (minibatch, `num_factors`)
        - features: (minibatch, `features_dim`)
    """

    def __init__(self, backbone: nn.Module, num_factors: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim=-1, head: Optional[nn.Module] = None, finetune=True):
        super(Regressor, self).__init__()
        self.backbone = backbone
        self.num_factors = num_factors
        if bottleneck is None:
            feature_dim = backbone.out_features
            self.bottleneck = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1),
                # BUG FIX: BatchNorm2d's second positional argument is `eps`, so the
                # previous `nn.BatchNorm2d(feature_dim, feature_dim)` silently set
                # eps=feature_dim and crushed the normalized activations toward zero.
                nn.BatchNorm2d(feature_dim),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
            self._features_dim = feature_dim
        else:
            self.bottleneck = bottleneck
            assert bottleneck_dim > 0
            self._features_dim = bottleneck_dim
        if head is None:
            # Sigmoid head: factors are predicted in [0, 1].
            self.head = nn.Sequential(
                nn.Linear(self._features_dim, num_factors),
                nn.Sigmoid()
            )
        else:
            self.head = head
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """"""
        f = self.backbone(x)
        f = self.bottleneck(f)
        predictions = self.head(f)
        # During training also expose the bottleneck features, used by adaptation losses.
        if self.training:
            return predictions, f
        else:
            return predictions

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        params = [
            # A pretrained backbone is finetuned with a 10x smaller learning rate.
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
        return params
================================================
FILE: tllib/normalization/__init__.py
================================================
================================================
FILE: tllib/normalization/afn.py
================================================
"""
Modified from https://github.com/jihanyang/AFN
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
import math
from tllib.modules.classifier import Classifier as ClassfierBase
class AdaptiveFeatureNorm(nn.Module):
    r"""
    The `Stepwise Adaptive Feature Norm loss (ICCV 2019) `_

    Instead of matching a restrictive scalar R, Stepwise Adaptive Feature Norm
    pushes each sample's feature norm toward its own (detached) norm enlarged by a
    positive residual, learning task-specific features with large norms progressively.

    Args:
        delta (float): positive residual scalar to control the feature norm enlargement.

    Inputs:
        - f (tensor): feature representations on source or target domain.

    Shape:
        - f: :math:`(N, F)` where F means the dimension of input features.
        - Outputs: scalar.

    Examples::
        >>> adaptive_feature_norm = AdaptiveFeatureNorm(delta=1)
        >>> f_s = torch.randn(32, 1000)
        >>> f_t = torch.randn(32, 1000)
        >>> norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
    """

    def __init__(self, delta):
        super(AdaptiveFeatureNorm, self).__init__()
        self.delta = delta

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        # Target norm: the current norm (treated as a constant) enlarged by delta.
        target_norm = f.norm(p=2, dim=1).detach()
        assert not target_norm.requires_grad
        target_norm = target_norm + self.delta
        # Squared gap between the live norm and the enlarged target, averaged over the batch.
        return ((f.norm(p=2, dim=1) - target_norm) ** 2).mean()
class Block(nn.Module):
    r"""
    Basic building block for Image Classifier with structure: FC-BN-ReLU-Dropout.

    We use :math:`L_2` preserved dropout layers. Given mask probability :math:`p`,
    input :math:`x_k`, generated mask :math:`a_k`, vanilla dropout layers calculate

    .. math::
        \hat{x}_k = a_k\frac{1}{1-p}x_k\\

    While in :math:`L_2` preserved dropout layers

    .. math::
        \hat{x}_k = a_k\frac{1}{\sqrt{1-p}}x_k\\

    Args:
        in_features (int): Dimension of input features
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000
        dropout_p (float, optional): dropout probability. Default: 0.5
    """

    def __init__(self, in_features: int, bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5):
        super(Block, self).__init__()
        self.fc = nn.Linear(in_features, bottleneck_dim)
        self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout_p)
        self.dropout_p = dropout_p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.dropout(self.relu(self.bn(self.fc(x))))
        if self.training:
            # Rescale vanilla dropout's 1/(1-p) factor to 1/sqrt(1-p), preserving the L2 norm.
            out.mul_(math.sqrt(1 - self.dropout_p))
        return out
class ImageClassifier(ClassfierBase):
    r"""
    ImageClassifier for AFN.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes
        num_blocks (int, optional): Number of basic blocks. Default: 1
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000
        dropout_p (float, optional): dropout probability. Default: 0.5
    """

    def __init__(self, backbone: nn.Module, num_classes: int, num_blocks: Optional[int] = 1,
                 bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5, **kwargs):
        assert num_blocks >= 1
        # The first block is wrapped in an extra Sequential (kept as-is so
        # existing checkpoints keep matching parameter names).
        blocks = [nn.Sequential(
            Block(backbone.out_features, bottleneck_dim, dropout_p)
        )]
        blocks.extend(Block(bottleneck_dim, bottleneck_dim, dropout_p) for _ in range(num_blocks - 1))
        super(ImageClassifier, self).__init__(backbone, num_classes, nn.Sequential(*blocks),
                                              bottleneck_dim, **kwargs)
        # init parameters for bottleneck and head
        for m in self.bottleneck.modules():
            if isinstance(m, nn.BatchNorm1d):
                m.weight.data.normal_(1.0, 0.01)
                m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.01)
                m.bias.data.normal_(0.0, 0.01)
        for m in self.head.modules():
            if isinstance(m, nn.Linear):
                m.weight.data.normal_(0.0, 0.01)
                m.bias.data.normal_(0.0, 0.01)

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        # AFN keeps one learning rate for all parts but adds momentum for the new layers.
        return [
            {"params": self.backbone.parameters()},
            {"params": self.bottleneck.parameters(), "momentum": 0.9},
            {"params": self.head.parameters(), "momentum": 0.9},
        ]
================================================
FILE: tllib/normalization/ibn.py
================================================
"""
Modified from https://github.com/XingangPan/IBN-Net
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import math
import torch
import torch.nn as nn
# Public factory functions exported by this module.
__all__ = ['resnet18_ibn_a', 'resnet18_ibn_b', 'resnet34_ibn_a', 'resnet34_ibn_b', 'resnet50_ibn_a', 'resnet50_ibn_b',
           'resnet101_ibn_a', 'resnet101_ibn_b']
# Download URLs for ImageNet pre-trained checkpoints released by the IBN-Net authors.
model_urls = {
    'resnet18_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
    'resnet34_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
    'resnet50_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
    'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
    'resnet18_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth',
    'resnet34_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_b-04134c37.pth',
    'resnet50_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_b-9ca61e85.pth',
    'resnet101_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_b-c55f6dba.pth',
}
class InstanceBatchNorm2d(nn.Module):
    r"""Instance-Batch Normalization layer from
    `Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (ECCV 2018)
    `_.

    Given input feature map :math:`f\_input` of dimension :math:`(C,H,W)`, we first split :math:`f\_input` into
    two parts along `channel` dimension. They are denoted as :math:`f_1` of dimension :math:`(C_1,H,W)` and
    :math:`f_2` of dimension :math:`(C_2,H,W)`, where :math:`C_1+C_2=C`. Then we pass :math:`f_1` and :math:`f_2`
    through IN and BN layer, respectively, to get :math:`IN(f_1)` and :math:`BN(f_2)`. Last, we concat them along
    `channel` dimension to create :math:`f\_output=concat(IN(f_1), BN(f_2))`.

    Args:
        planes (int): Number of channels for the input tensor
        ratio (float): Ratio of instance normalization in the IBN layer
    """

    def __init__(self, planes, ratio=0.5):
        super(InstanceBatchNorm2d, self).__init__()
        self.half = int(planes * ratio)
        self.IN = nn.InstanceNorm2d(self.half, affine=True)
        self.BN = nn.BatchNorm2d(planes - self.half)

    def forward(self, x):
        # BUG FIX: split into exactly two chunks of (half, rest) channels.
        # `torch.split(x, self.half, 1)` produced a *third* chunk whenever the
        # channel count was odd, so the BN branch received the wrong width.
        split = torch.split(x, [self.half, x.size(1) - self.half], 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out
class BasicBlock(nn.Module):
    """ResNet basic block with optional IBN: 'a' replaces bn1 with Instance-Batch
    Norm, 'b' appends an extra InstanceNorm after the residual addition."""
    expansion = 1

    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = InstanceBatchNorm2d(planes) if ibn == 'a' else nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == 'b' else None
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: projected when a downsample module is supplied.
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        # IBN-b: InstanceNorm applied after the residual addition.
        if self.IN is not None:
            out = self.IN(out)
        return self.relu(out)
class Bottleneck(nn.Module):
    """ResNet bottleneck block (1x1 -> 3x3 -> 1x1) with optional IBN: 'a' replaces
    bn1 with Instance-Batch Norm, 'b' appends InstanceNorm after the residual add."""
    expansion = 4

    def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = InstanceBatchNorm2d(planes) if ibn == 'a' else nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.IN = nn.InstanceNorm2d(planes * 4, affine=True) if ibn == 'b' else None
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut branch: projected when a downsample module is supplied.
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        # IBN-b: InstanceNorm applied after the residual addition.
        if self.IN is not None:
            out = self.IN(out)
        return self.relu(out)
class IBNNet(nn.Module):
    r"""
    IBNNet without fully connected layer

    The backbone follows the standard 4-stage ResNet layout; ``ibn_cfg`` chooses the
    IBN variant ('a', 'b' or None) per stage.
    """
    def __init__(self, block, layers, ibn_cfg=('a', 'a', 'a', None)):
        # inplanes must exist before _make_layer mutates it stage by stage
        self.inplanes = 64
        super(IBNNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        # IBN-b replaces the stem BatchNorm with an affine InstanceNorm
        if ibn_cfg[0] == 'b':
            self.bn1 = nn.InstanceNorm2d(64, affine=True)
        else:
            self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3])
        self._out_features = 512 * block.expansion
        # He-style init for convs; unit weight / zero bias for the (affine) norm layers.
        # NOTE(review): this assumes every InstanceNorm2d in the net is affine=True,
        # which holds for the blocks defined in this file.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
    def _make_layer(self, block, planes, blocks, stride=1, ibn=None):
        """Build one residual stage; for ibn='b' only the *last* block of the stage
        receives the 'b' flag (extra IN after its residual add)."""
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = []
        # first block: 'b' is withheld here and applied only on the final block below
        layers.append(block(self.inplanes, planes,
                            None if ibn == 'b' else ibn,
                            stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                None if (ibn == 'b' and i < blocks - 1) else ibn))
        return nn.Sequential(*layers)
    def forward(self, x):
        """"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x
    @property
    def out_features(self) -> int:
        """The dimension of output features"""
        return self._out_features
def _ibn_resnet(arch, block, layers, ibn_cfg, pretrained):
    """Build an :class:`IBNNet` and optionally load ImageNet weights.

    Args:
        arch (str): key into ``model_urls`` for the pretrained checkpoint
        block (class): residual block class (``BasicBlock`` or ``Bottleneck``)
        layers (list): number of blocks in each of the four stages
        ibn_cfg (tuple): per-stage IBN variant ('a', 'b' or None)
        pretrained (bool): if True, load ImageNet pre-trained parameters
            (non-strict, since :class:`IBNNet` has no classification head)
    """
    model = IBNNet(block=block, layers=layers, ibn_cfg=ibn_cfg)
    if pretrained:
        model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls[arch]), strict=False)
    return model


def resnet18_ibn_a(pretrained=False):
    """Constructs a ResNet-18-IBN-a model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet18_ibn_a', BasicBlock, [2, 2, 2, 2], ('a', 'a', 'a', None), pretrained)


def resnet34_ibn_a(pretrained=False):
    """Constructs a ResNet-34-IBN-a model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet34_ibn_a', BasicBlock, [3, 4, 6, 3], ('a', 'a', 'a', None), pretrained)


def resnet50_ibn_a(pretrained=False):
    """Constructs a ResNet-50-IBN-a model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet50_ibn_a', Bottleneck, [3, 4, 6, 3], ('a', 'a', 'a', None), pretrained)


def resnet101_ibn_a(pretrained=False):
    """Constructs a ResNet-101-IBN-a model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet101_ibn_a', Bottleneck, [3, 4, 23, 3], ('a', 'a', 'a', None), pretrained)


def resnet18_ibn_b(pretrained=False):
    """Constructs a ResNet-18-IBN-b model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet18_ibn_b', BasicBlock, [2, 2, 2, 2], ('b', 'b', None, None), pretrained)


def resnet34_ibn_b(pretrained=False):
    """Constructs a ResNet-34-IBN-b model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet34_ibn_b', BasicBlock, [3, 4, 6, 3], ('b', 'b', None, None), pretrained)


def resnet50_ibn_b(pretrained=False):
    """Constructs a ResNet-50-IBN-b model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet50_ibn_b', Bottleneck, [3, 4, 6, 3], ('b', 'b', None, None), pretrained)


def resnet101_ibn_b(pretrained=False):
    """Constructs a ResNet-101-IBN-b model.
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    return _ibn_resnet('resnet101_ibn_b', Bottleneck, [3, 4, 23, 3], ('b', 'b', None, None), pretrained)
================================================
FILE: tllib/normalization/mixstyle/__init__.py
================================================
"""
Modified from https://github.com/KaiyangZhou/mixstyle-release
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import torch
import torch.nn as nn
class MixStyle(nn.Module):
    r"""MixStyle module from `DOMAIN GENERALIZATION WITH MIXSTYLE (ICLR 2021) `_.

    Given input :math:`x`, we first compute mean :math:`\mu(x)` and standard deviation :math:`\sigma(x)` across spatial
    dimension. Then we permute :math:`x` and get :math:`\tilde{x}`, corresponding mean :math:`\mu(\tilde{x})` and
    standard deviation :math:`\sigma(\tilde{x})`. `MixUp` is performed using mean and standard deviation

    .. math::
        \gamma_{mix} = \lambda\sigma(x) + (1-\lambda)\sigma(\tilde{x})

    .. math::
        \beta_{mix} = \lambda\mu(x) + (1-\lambda)\mu(\tilde{x})

    where :math:`\lambda` is instance-wise weight sampled from `Beta distribution`. MixStyle is then

    .. math::
        MixStyle(x) = \gamma_{mix}\frac{x-\mu(x)}{\sigma(x)} + \beta_{mix}

    Args:
        p (float): probability of using MixStyle.
        alpha (float): parameter of the `Beta distribution`.
        eps (float): scaling parameter to avoid numerical issues.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

    def forward(self, x):
        # MixStyle is active only during training, and only with probability p.
        if not self.training:
            return x
        if random.random() > self.p:
            return x
        batch_size = x.size(0)
        # Per-instance channel statistics over the spatial dimensions (detached:
        # no gradient flows through the style statistics).
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sigma = (var + self.eps).sqrt()
        mu, sigma = mu.detach(), sigma.detach()
        x_normed = (x - mu) / sigma
        interpolation = self.beta.sample((batch_size, 1, 1, 1))
        interpolation = interpolation.to(x.device)
        # split into two halves and swap the order
        perm = torch.arange(batch_size - 1, -1, -1)  # inverse index
        perm_b, perm_a = perm.chunk(2)
        # BUG FIX: shuffle each half with its *own* length. `chunk(2)` gives the
        # first half one extra element for odd batch sizes; the old
        # `randperm(batch_size // 2)` then dropped an index, so `mu[perm]` no
        # longer matched the batch and the forward pass crashed. Identical
        # behaviour for even batch sizes.
        perm_b = perm_b[torch.randperm(perm_b.size(0))]
        perm_a = perm_a[torch.randperm(perm_a.size(0))]
        perm = torch.cat([perm_b, perm_a], 0)
        mu_perm, sigma_perm = mu[perm], sigma[perm]
        # Mix the style statistics and re-apply them to the normalized content.
        mu_mix = mu * interpolation + mu_perm * (1 - interpolation)
        sigma_mix = sigma * interpolation + sigma_perm * (1 - interpolation)
        return x_normed * sigma_mix + mu_mix
================================================
FILE: tllib/normalization/mixstyle/resnet.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from . import MixStyle
from tllib.vision.models.reid.resnet import ReidResNet
from tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck
__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101']
def _resnet_with_mix_style(arch, block, layers, pretrained, progress, mix_layers=None, mix_p=0.5, mix_alpha=0.1,
                           resnet_class=ResNet, **kwargs):
    """Construct `ResNet` with MixStyle modules. Given any resnet architecture **resnet_class** that contains conv1,
    bn1, relu, maxpool, layer1-4, this function define a new class that inherits from **resnet_class** and inserts
    MixStyle module during forward pass. Although MixStyle Module can be inserted anywhere, original paper finds it
    better to place MixStyle after layer1-3. Our implementation follows this idea, but you are free to modify this
    function to try other possibilities.

    Args:
        arch (str): resnet architecture (resnet50 for example)
        block (class): class of resnet block
        layers (list): depth list of each block
        pretrained (bool): if True, load imagenet pre-trained model parameters
        progress (bool): whether or not to display a progress bar to stderr
        mix_layers (list): layers to insert MixStyle module after
        mix_p (float): probability to activate MixStyle during forward pass
        mix_alpha (float): parameter alpha for beta distribution
        resnet_class (class): base resnet class to inherit from
    """
    # avoid a mutable default argument; an empty list means "no MixStyle inserted"
    if mix_layers is None:
        mix_layers = []
    available_resnet_class = [ResNet, ReidResNet]
    assert resnet_class in available_resnet_class
    # Subclass is defined inside the function so it can close over `resnet_class`.
    class ResNetWithMixStyleModule(resnet_class):
        def __init__(self, mix_layers, mix_p=0.5, mix_alpha=0.1, *args, **kwargs):
            super(ResNetWithMixStyleModule, self).__init__(*args, **kwargs)
            # a single shared MixStyle module is reused after every requested layer
            self.mixStyleModule = MixStyle(p=mix_p, alpha=mix_alpha)
            for layer in mix_layers:
                assert layer in ['layer1', 'layer2', 'layer3']
            self.apply_layers = mix_layers
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            # turn on relu activation here **except for** reid tasks
            if resnet_class != ReidResNet:
                x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            if 'layer1' in self.apply_layers:
                x = self.mixStyleModule(x)
            x = self.layer2(x)
            if 'layer2' in self.apply_layers:
                x = self.mixStyleModule(x)
            x = self.layer3(x)
            if 'layer3' in self.apply_layers:
                x = self.mixStyleModule(x)
            x = self.layer4(x)
            return x
    model = ResNetWithMixStyleModule(mix_layers=mix_layers, mix_p=mix_p, mix_alpha=mix_alpha, block=block,
                                     layers=layers, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = load_state_dict_from_url(model_urls[arch],
                                                   progress=progress)
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    """ResNet-18 backbone equipped with MixStyle.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_with_mix_style('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    """ResNet-34 backbone equipped with MixStyle.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_with_mix_style('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    """ResNet-50 backbone equipped with MixStyle.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_with_mix_style('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    """ResNet-101 backbone equipped with MixStyle.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_with_mix_style('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
================================================
FILE: tllib/normalization/mixstyle/sampler.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import copy
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.sampler import Sampler
class RandomDomainMultiInstanceSampler(Sampler):
    r"""Randomly sample :math:`N` domains, then randomly select :math:`P` instances in each domain, for each instance,
    randomly select :math:`K` images to form a mini-batch of size :math:`N\times P\times K`.

    Args:
        dataset (ConcatDataset): dataset that contains data from multiple domains; each item must be a
            ``(image, class_label, domain_id)`` triple
        batch_size (int): mini-batch size (:math:`N\times P\times K` here)
        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)
        num_instances (int): number of instances to select in each domain (:math:`K` here)
    """

    def __init__(self, dataset, batch_size, n_domains_per_batch, num_instances):
        # NOTE(review): super(Sampler, self).__init__() deliberately skips
        # Sampler.__init__ (which historically expects a data_source argument)
        # — confirm this matches the torch version in use.
        super(Sampler, self).__init__()
        self.dataset = dataset
        # domain_id -> list of dataset indices belonging to that domain,
        # built by one full pass over the dataset
        self.sample_idxes_per_domain = {}
        for idx, (_, _, domain_id) in enumerate(self.dataset):
            if domain_id not in self.sample_idxes_per_domain:
                self.sample_idxes_per_domain[domain_id] = []
            self.sample_idxes_per_domain[domain_id].append(idx)
        self.n_domains_in_dataset = len(self.sample_idxes_per_domain)
        self.n_domains_per_batch = n_domains_per_batch
        assert self.n_domains_in_dataset >= self.n_domains_per_batch
        assert batch_size % n_domains_per_batch == 0
        # every selected domain contributes the same share of each mini-batch
        self.batch_size_per_domain = batch_size // n_domains_per_batch
        assert self.batch_size_per_domain % num_instances == 0
        self.num_instances = num_instances
        # P in the class docstring: distinct classes drawn per domain per batch
        self.num_classes_per_domain = self.batch_size_per_domain // num_instances
        # materialize one full epoch up front to learn the sampler length
        # (note: this consumes random state once at construction time)
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # operate on a deep copy so each epoch restarts from the full index pool
        sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)
        # assumes domain ids are contiguous integers 0..n_domains-1 — TODO confirm
        domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]
        final_idxes = []
        stop_flag = False
        while not stop_flag:
            selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)
            for domain in selected_domains:
                sample_idxes = sample_idxes_per_domain[domain]
                selected_idxes = self.sample_multi_instances(sample_idxes)
                final_idxes.extend(selected_idxes)
                # remove used indices: sampling is without replacement within an epoch
                for idx in selected_idxes:
                    sample_idxes_per_domain[domain].remove(idx)
                remaining_size = len(sample_idxes_per_domain[domain])
                # end the epoch once any visited domain can no longer fill its share
                if remaining_size < self.batch_size_per_domain:
                    stop_flag = True
        return iter(final_idxes)

    def sample_multi_instances(self, sample_idxes):
        """Pick ``num_classes_per_domain`` classes and ``num_instances`` samples per class
        from the given candidate indices (falling back to plain random sampling when too
        few classes remain with enough samples)."""
        # group the remaining candidate indices by class label
        idxes_per_cls = {}
        for idx in sample_idxes:
            _, cls, _ = self.dataset[idx]
            if cls not in idxes_per_cls:
                idxes_per_cls[cls] = []
            idxes_per_cls[cls].append(idx)
        # classes that can still supply a full group of num_instances samples
        cls_list = [cls for cls in idxes_per_cls if len(idxes_per_cls[cls]) >= self.num_instances]
        if len(cls_list) < self.num_classes_per_domain:
            # not enough eligible classes: ignore class balance for this batch
            return random.sample(sample_idxes, self.batch_size_per_domain)
        selected_idxes = []
        selected_classes = random.sample(cls_list, self.num_classes_per_domain)
        for cls in selected_classes:
            selected_idxes.extend(random.sample(idxes_per_cls[cls], self.num_instances))
        return selected_idxes

    def __len__(self):
        return self.length
================================================
FILE: tllib/normalization/stochnorm.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
__all__ = ['StochNorm1d', 'StochNorm2d', 'convert_model']
class _StochNorm(nn.Module):
    """Base class for Stochastic Normalization layers.

    During training two normalization branches are computed:

    * ``z_0`` normalizes with the moving statistics (like eval-mode BN, no update),
    * ``z_1`` normalizes with the mini-batch statistics (like train-mode BN, and it
      updates the moving statistics).

    A per-channel mask ``s ~ Bernoulli(p)`` then mixes them:
    ``z = (1 - s) * z_0 + s * z_1``. At evaluation time the moving statistics are
    used, exactly as in standard batch normalization.

    Args:
        num_features (int): number of channels C of the input.
        eps (float): value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): value used for the running_mean/running_var update. Default: 0.1
        affine (bool): if True, adds learnable affine parameters. Default: True
        track_running_stats (bool): if True, tracks running mean and variance. Default: True
        p (float): probability of choosing the mini-batch-statistics branch. Default: 0.5
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, p=0.5):
        super(_StochNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.p = p
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset running statistics and (re-)initialize the affine parameters."""
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        # subclasses override this with a dimensionality check
        return NotImplemented

    def forward(self, input):
        self._check_input_dim(input)
        if not self.training:
            # inference: plain BN with the moving statistics
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                False, self.momentum, self.eps)
        if not 2 <= input.dim() <= 5:
            # validate before the batch_norm calls so an invalid input cannot
            # corrupt the running statistics (5D added for StochNorm3d support)
            raise ValueError('expected 2D to 5D input (got {}D input)'.format(input.dim()))
        # branch 0: normalize with moving statistics (no update)
        z_0 = F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            False, self.momentum, self.eps)
        # branch 1: normalize with mini-batch statistics (updates moving stats)
        z_1 = F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
        # per-channel Bernoulli(p) selection mask, broadcast over the non-channel
        # dims; created on the input's device (the original hard-coded .cuda(),
        # which crashed on CPU-only runs)
        shape = (1, self.num_features) + (1,) * (input.dim() - 2)
        s = torch.from_numpy(
            np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(shape)).float().to(input.device)
        return (1 - s) * z_0 + s * z_1
class StochNorm1d(_StochNorm):
    r"""Applies Stochastic Normalization over a 2D or 3D input (a mini-batch of 1D inputs
    with an optional extra channel dimension), as proposed in
    `Stochastic Normalization (NIPS 2020) `_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{ \sqrt{\tilde{\sigma} + \epsilon}}

        \hat{x}_{i,1} = \frac{x_i - \mu}{ \sqrt{\sigma + \epsilon}}

        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1}

        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` the moving statistics, and :math:`s` a
    branch-selection variable with :math:`P(s=1)=p`. Training mixes the two
    normalization branches per channel; evaluation uses the moving statistics only.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, l)` or :math:`l` from an expected input of size :math:`(b, l)`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): The value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): When ``True``, the layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): When ``True``, running mean and variance are tracked;
            when ``False``, batch statistics are always used. Default: True
        p (float): The probability to choose the second branch (usual BN). Default: 0.5

    Shape:
        - Input: :math:`(b, l)` or :math:`(b, c, l)`
        - Output: :math:`(b, l)` or :math:`(b, c, l)` (same shape as input)
    """

    def _check_input_dim(self, input):
        # only (b, l) and (b, c, l) inputs are meaningful for 1D normalization
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
class StochNorm2d(_StochNorm):
    r"""Applies Stochastic Normalization over a 4D input (a mini-batch of 2D inputs
    with an extra channel dimension), as proposed in
    `Stochastic Normalization (NIPS 2020) `_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{ \sqrt{\tilde{\sigma} + \epsilon}}

        \hat{x}_{i,1} = \frac{x_i - \mu}{ \sqrt{\sigma + \epsilon}}

        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1}

        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` the moving statistics, and :math:`s` a
    branch-selection variable with :math:`P(s=1)=p`. Training mixes the two
    normalization branches per channel; evaluation uses the moving statistics only.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, h, w)`.
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): The value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): When ``True``, the layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): When ``True``, running mean and variance are tracked;
            when ``False``, batch statistics are always used. Default: True
        p (float): The probability to choose the second branch (usual BN). Default: 0.5

    Shape:
        - Input: :math:`(b, c, h, w)`
        - Output: :math:`(b, c, h, w)` (same shape as input)
    """

    def _check_input_dim(self, input):
        # image-style inputs only: (batch, channels, height, width)
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'.format(input.dim()))
class StochNorm3d(_StochNorm):
    r"""Applies Stochastic Normalization over a 5D input (a mini-batch of 3D inputs
    with an extra channel dimension), as proposed in
    `Stochastic Normalization (NIPS 2020) `_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{ \sqrt{\tilde{\sigma} + \epsilon}}

        \hat{x}_{i,1} = \frac{x_i - \mu}{ \sqrt{\sigma + \epsilon}}

        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1}

        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` the moving statistics, and :math:`s` a
    branch-selection variable with :math:`P(s=1)=p`. Training mixes the two
    normalization branches per channel; evaluation uses the moving statistics only.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, d, h, w)`
        eps (float): A value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): The value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): When ``True``, the layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): When ``True``, running mean and variance are tracked;
            when ``False``, batch statistics are always used. Default: True
        p (float): The probability to choose the second branch (usual BN). Default: 0.5

    Shape:
        - Input: :math:`(b, c, d, h, w)`
        - Output: :math:`(b, c, d, h, w)` (same shape as input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            # message fixed: this layer expects 5D input (it previously said "4D")
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
def convert_model(module, p):
    """
    Traverse ``module`` and its children recursively, replacing every
    BatchNorm1d/2d/3d instance with the matching StochNorm layer.

    Args:
        module (torch.nn.Module): The module to convert to its StochNorm version.
        p (float): The branch-selection probability for every StochNorm layer.

    Returns:
        The module converted to StochNorm version.
    """
    mod = module
    for pth_module, stoch_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                         torch.nn.modules.batchnorm.BatchNorm2d,
                                         torch.nn.modules.batchnorm.BatchNorm3d],
                                        [StochNorm1d,
                                         StochNorm2d,
                                         StochNorm3d]):
        if isinstance(module, pth_module):
            # pass p by keyword: previously it was passed positionally and landed
            # in the track_running_stats slot, silently discarding the user's p
            mod = stoch_module(module.num_features, module.eps, module.momentum, module.affine,
                               module.track_running_stats, p=p)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_model(child, p))
    return mod
================================================
FILE: tllib/ranking/__init__.py
================================================
from .logme import log_maximum_evidence
from .nce import negative_conditional_entropy
from .leep import log_expected_empirical_prediction
from .hscore import h_score
__all__ = ['log_maximum_evidence', 'negative_conditional_entropy', 'log_expected_empirical_prediction', 'h_score']
================================================
FILE: tllib/ranking/hscore.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np
from sklearn.covariance import LedoitWolf
__all__ = ['h_score', 'regularized_h_score']
def h_score(features: np.ndarray, labels: np.ndarray):
    r"""
    H-score in `An Information-theoretic Approach to Transferability in Task Transfer Learning (ICIP 2019)
    `_.

    The H-Score :math:`\mathcal{H}` can be described as:

    .. math::
        \mathcal{H}=\operatorname{tr}\left(\operatorname{cov}(f)^{-1} \operatorname{cov}\left(\mathbb{E}[f \mid y]\right)\right)

    where :math:`f` are the features extracted by the model to be ranked and
    :math:`y` is the ground-truth label vector.

    Args:
        features (np.ndarray): features extracted by pre-trained model.
        labels (np.ndarray): ground-truth labels.

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    feats = features
    targets = labels
    num_classes = int(targets.max() + 1)
    # conditional-expectation features: every sample replaced by its class mean
    cond_mean = np.zeros_like(feats)
    for c in range(num_classes):
        cond_mean[targets == c] = np.mean(feats[targets == c, :], axis=0)
    cov_f = np.cov(feats, rowvar=False)
    cov_g = np.cov(cond_mean, rowvar=False)
    # pseudo-inverse guards against a singular feature covariance
    return np.trace(np.dot(np.linalg.pinv(cov_f, rcond=1e-15), cov_g))
def regularized_h_score(features: np.ndarray, labels: np.ndarray):
    r"""
    Regularized H-score in `Newer is not always better: Rethinking transferability metrics, their peculiarities, stability and performance (NeurIPS 2021)
    `_.

    Identical to :func:`h_score` except that the feature covariance is replaced by the
    shrinkage-regularized Ledoit-Wolf estimate :math:`\operatorname{cov}_{\alpha}(f)` and the
    between-class covariance is scaled by :math:`1-\alpha`:

    .. math::
        \mathcal{H}_{\alpha}=\operatorname{tr}\left(\operatorname{cov}_{\alpha}(f)^{-1}\left(1-\alpha \right)\operatorname{cov}\left(\mathbb{E}[f \mid y]\right)\right)

    where :math:`\alpha` is the shrinkage parameter selected by the Ledoit-Wolf estimator.

    Args:
        features (np.ndarray): features extracted by pre-trained model.
        labels (np.ndarray): ground-truth labels.

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    feats = features.astype('float64')
    # center features so the Ledoit-Wolf shrinkage estimate is computed correctly
    feats = feats - np.mean(feats, axis=0, keepdims=True)
    targets = labels
    num_classes = int(targets.max() + 1)
    shrunk = LedoitWolf(assume_centered=False).fit(feats)
    alpha = shrunk.shrinkage_
    cov_f_alpha = shrunk.covariance_
    # conditional-expectation features: every sample replaced by its class mean
    cond_mean = np.zeros_like(feats)
    for c in range(num_classes):
        cond_mean[targets == c] = np.mean(feats[targets == c, :], axis=0)
    cov_g = np.cov(cond_mean, rowvar=False)
    return np.trace(np.dot(np.linalg.pinv(cov_f_alpha, rcond=1e-15), (1 - alpha) * cov_g))
================================================
FILE: tllib/ranking/leep.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np
__all__ = ['log_expected_empirical_prediction']
def log_expected_empirical_prediction(predictions: np.ndarray, labels: np.ndarray):
    r"""
    Log Expected Empirical Prediction in `LEEP: A New Measure to
    Evaluate Transferability of Learned Representations (ICML 2020)
    `_.

    The LEEP :math:`\mathcal{T}` can be described as:

    .. math::
        \mathcal{T}=\mathbb{E}\log \left(\sum_{z \in \mathcal{C}_s} \hat{P}\left(y \mid z\right) \theta\left(y \right)_{z}\right)

    where :math:`\theta\left(y\right)_{z}` are the pre-trained model's predictions over the
    source categories and :math:`\hat{P}\left(y \mid z\right)` is the empirical conditional
    distribution estimated from predictions and ground-truth labels.

    Args:
        predictions (np.ndarray): predictions of pre-trained model.
        labels (np.ndarray): ground-truth labels.

    Shape:
        - predictions: (N, :math:`C_s`), with number of samples N and source class number :math:`C_s`.
        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar
    """
    num_samples, num_src_classes = predictions.shape
    labels = labels.reshape(-1)
    num_tgt_classes = int(np.max(labels) + 1)

    # empirical joint distribution over (target label y, source class z)
    prob = predictions / float(num_samples)
    joint = np.zeros((num_tgt_classes, num_src_classes), dtype=float)
    for target_cls in range(num_tgt_classes):
        joint[target_cls] = np.sum(prob[labels == target_cls], axis=0)

    # P(y | z), shape (C_s, C_t)
    p_target_given_source = (joint / joint.sum(axis=0, keepdims=True)).T
    leep_predictions = predictions @ p_target_given_source
    # probability each sample assigns to its own ground-truth label
    per_sample_prob = leep_predictions[np.arange(num_samples), labels]
    return np.mean(np.log(per_sample_prob))
================================================
FILE: tllib/ranking/logme.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np
from numba import njit
__all__ = ['log_maximum_evidence']
def log_maximum_evidence(features: np.ndarray, targets: np.ndarray, regression=False, return_weights=False):
    r"""
    Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models
    for Transfer Learning (ICML 2021) `_.

    Args:
        features (np.ndarray): feature matrix from pre-trained model.
        targets (np.ndarray): targets labels/values.
        regression (bool, optional): whether to apply in regression setting. (Default: False)
        return_weights (bool, optional): whether to return bayesian weight. (Default: False)

    Shape:
        - features: (N, F) with element in [0, :math:`C_t`) and feature dimension F, where :math:`C_t` denotes the number of target class
        - targets: (N, ) or (N, C), with C regression-labels.
        - weights: (F, :math:`C_t`).
        - score: scalar.
    """
    fh = features.astype(np.float64)            # (N, F)
    y = targets.astype(np.float64) if regression else targets
    f = fh.transpose()                          # (F, N)
    D, N = f.shape
    # SVD of the (F, F) Gram matrix, shared by every target column
    v, s, vh = np.linalg.svd(f @ fh, full_matrices=True)

    # one evidence computation per regression target / one-vs-rest class indicator
    if regression:
        target_columns = [y[:, i] for i in range(y.shape[1])]
    else:
        num_classes = int(y.max() + 1)
        target_columns = [(y == i).astype(np.float64) for i in range(num_classes)]

    evidences = []
    weights = []
    for y_ in target_columns:
        evidence, weight = each_evidence(y_, f, fh, v, s, vh, N, D)
        evidences.append(evidence)
        weights.append(weight)

    score = np.mean(evidences)
    weights = np.vstack(weights)
    if return_weights:
        return score, weights
    return score
@njit
def each_evidence(y_, f, fh, v, s, vh, N, D):
    """
    compute the maximum evidence for each class

    Fixed-point iteration over the two precision hyper-parameters of Bayesian
    linear regression (the LogME inner loop): alpha is the weight-prior
    precision, beta the observation-noise precision. ``f`` is (D, N), ``fh``
    its transpose, and (v, s, vh) the SVD of ``f @ fh``. Returns the
    per-sample log evidence and the posterior mean weight vector m.
    """
    alpha = 1.0
    beta = 1.0
    lam = alpha / beta
    # project the targets into the SVD basis once; reused every iteration
    tmp = (vh @ (f @ y_))
    for _ in range(11):
        # should converge after at most 10 steps
        # typically converge after two or three steps
        gamma = (s / (s + lam)).sum()  # effective number of well-determined parameters
        m = v @ (tmp * beta / (alpha + beta * s))  # posterior mean of the weights
        alpha_de = (m * m).sum()
        alpha = gamma / alpha_de
        beta_de = ((y_ - fh @ m) ** 2).sum()  # residual sum of squares
        beta = (N - gamma) / beta_de
        new_lam = alpha / beta
        # stop when the ratio alpha/beta moves by less than 1%
        if np.abs(new_lam - lam) / lam < 0.01:
            break
        lam = new_lam
    # log evidence of the converged (alpha, beta); normalized per sample below
    evidence = D / 2.0 * np.log(alpha) \
               + N / 2.0 * np.log(beta) \
               - 0.5 * np.sum(np.log(alpha + beta * s)) \
               - beta / 2.0 * beta_de \
               - alpha / 2.0 * alpha_de \
               - N / 2.0 * np.log(2 * np.pi)
    return evidence / N, m
================================================
FILE: tllib/ranking/nce.py
================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np
__all__ = ['negative_conditional_entropy']
def negative_conditional_entropy(source_labels: np.ndarray, target_labels: np.ndarray):
    r"""
    Negative Conditional Entropy in `Transferability and Hardness of Supervised
    Classification Tasks (ICCV 2019) `_.

    The NCE :math:`\mathcal{H}` can be described as:

    .. math::
        \mathcal{H}=-\sum_{y \in \mathcal{C}_t} \sum_{z \in \mathcal{C}_s} \hat{P}(y, z) \log \frac{\hat{P}(y, z)}{\hat{P}(z)}

    where :math:`\hat{P}(z)` is the empirical marginal over source labels and
    :math:`\hat{P}\left(y \mid z\right)` the empirical conditional distribution
    estimated from the source and target labels.

    Args:
        source_labels (np.ndarray): predicted source labels.
        target_labels (np.ndarray): ground-truth target labels.

    Shape:
        - source_labels: (N, ) elements in [0, :math:`C_s`), with source class number :math:`C_s`.
        - target_labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.
    """
    num_tgt_classes = int(np.max(target_labels) + 1)
    num_src_classes = int(np.max(source_labels) + 1)
    num_samples = len(source_labels)

    # empirical joint distribution P(y, z), shape (C_t, C_s)
    joint = np.zeros((num_tgt_classes, num_src_classes), dtype=float)
    for src, tgt in zip(source_labels, target_labels):
        joint[int(tgt), int(src)] += 1.0 / num_samples

    p_z = joint.sum(axis=0, keepdims=True)          # marginal P(z), shape (1, C_s)
    p_target_given_source = (joint / p_z).T         # P(y | z), shape (C_s, C_t)
    valid = p_z.reshape(-1) != 0                    # ignore source classes never predicted
    # add 1e-20 to avoid log(0); NaN rows from P(z) = 0 are masked out above
    p_target_given_source = p_target_given_source[valid] + 1e-20
    entropy_y_given_z = np.sum(-p_target_given_source * np.log(p_target_given_source), axis=1, keepdims=True)
    conditional_entropy = np.sum(entropy_y_given_z * p_z.reshape((-1, 1))[valid])
    return -conditional_entropy
================================================
FILE: tllib/ranking/transrate.py
================================================
"""
@author: Louis Fouquet
@contact: louisfouquet75@gmail.com
"""
import numpy as np
__all__ = ['transrate']
def coding_rate(features: np.ndarray, eps=1e-4):
    """Return the rate-distortion estimate :math:`\\frac{1}{2}\\log\\det(I_d + \\frac{1}{n\\epsilon} f^\\top f)`
    for a feature matrix ``features`` of shape (n, d)."""
    n, d = features.shape
    scale = 1 / (n * eps)
    # slogdet avoids overflow/underflow of the determinant itself
    _, logdet = np.linalg.slogdet(np.eye(d) + (scale * features.transpose()) @ features)
    return 0.5 * logdet
def transrate(features: np.ndarray, labels: np.ndarray, eps=1e-4):
    r"""
    TransRate in `Frustratingly easy transferability estimation (ICML 2022)
    `_.

    The TransRate :math:`TrR` can be described as:

    .. math::
        TrR= R\left(f, \epsilon \right) - R\left(f, \epsilon \mid y \right)

    where :math:`f` are the (mean-centered) features extracted by the model to be ranked,
    :math:`y` is the ground-truth label vector and :math:`R` is the coding rate with
    distortion rate :math:`\epsilon`.

    Args:
        features (np.ndarray): features extracted by pre-trained model.
        labels (np.ndarray): ground-truth labels.
        eps (float, optional): distortion rate (Default: 1e-4)

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    centered = features - np.mean(features, axis=0, keepdims=True)
    whole_rate = coding_rate(centered, eps)
    # average per-class coding rate approximates R(f, eps | y)
    num_classes = int(labels.max() + 1)
    per_class_rate = 0.0
    for cls in range(num_classes):
        per_class_rate += coding_rate(centered[(labels == cls).flatten()], eps)
    return whole_rate - per_class_rate / num_classes
================================================
FILE: tllib/regularization/__init__.py
================================================
from .bss import *
from .co_tuning import *
from .delta import *
from .bi_tuning import *
from .knowledge_distillation import *
__all__ = ['bss', 'co_tuning', 'delta', 'bi_tuning', 'knowledge_distillation']
================================================
FILE: tllib/regularization/bi_tuning.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import torch.nn as nn
from torch.nn.functional import normalize
from tllib.modules.classifier import Classifier as ClassifierBase
class Classifier(ClassifierBase):
    """Classifier class for Bi-Tuning.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes
        projection_dim (int, optional): Dimension of the projector head. Default: 128
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True

    .. note::
        The learning rate of this classifier is set 10 times to that of the feature extractor for better accuracy
        by default. If you have other optimization strategies, please over-ride :meth:`~Classifier.get_parameters`.

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        In the training mode,
            - y: classifier's predictions
            - z: projector's predictions
            - hn: normalized features after `bottleneck` layer and before `head` layer
        In the eval mode,
            - y: classifier's predictions

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - y: (minibatch, `num_classes`)
        - z: (minibatch, `projection_dim`)
        - hn: (minibatch, `features_dim`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=128, finetune=True, pool_layer=None):
        # build the classification head first so its init precedes the base init
        head = nn.Linear(backbone.out_features, num_classes)
        head.weight.data.normal_(0, 0.01)
        head.bias.data.fill_(0.0)
        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,
                                         pool_layer=pool_layer)
        # contrastive projection head used only during training
        self.projector = nn.Linear(backbone.out_features, projection_dim)
        self.projection_dim = projection_dim

    def forward(self, x: torch.Tensor):
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        predictions = self.head(features)
        if not self.training:
            return predictions
        projections = normalize(self.projector(features), dim=1)
        # append a constant-one coordinate (bias term) before normalizing
        ones = torch.ones(x.shape[0], 1, dtype=torch.float).to(features.device)
        normalized_features = normalize(torch.cat([features, ones], dim=1), dim=1)
        return predictions, projections, normalized_features

    def get_parameters(self, base_lr=1.0):
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        # pre-trained parts train slower when fine-tuning
        slow_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": slow_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.projector.parameters(), "lr": slow_lr},
        ]
class BiTuning(nn.Module):
"""
Bi-Tuning Module in `Bi-tuning of Pre-trained Representations `_.
Args:
encoder_q (Classifier): Query encoder.
encoder_k (Classifier): Key encoder.
num_classes (int): Number of classes
K (int): Queue size. Default: 40
m (float): Momentum coefficient. Default: 0.999
T (float): Temperature. Default: 0.07
Inputs:
- im_q (tensor): input data fed to `encoder_q`
- im_k (tensor): input data fed to `encoder_k`
- labels (tensor): classification labels of input data
Outputs: y_q, logits_z, logits_y, labels_c
- y_q: query classifier's predictions
- logits_z: projector's predictions on both positive and negative samples
- logits_y: classifier's predictions on both positive and negative samples
- labels_c: contrastive labels
Shape:
- im_q, im_k: (minibatch, *) where * means, any number of additional dimensions
- labels: (minibatch, )
- y_q: (minibatch, `num_classes`)
- logits_z: (minibatch, 1 + `num_classes` x `K`, `projection_dim`)
- logits_y: (minibatch, 1 + `num_classes` x `K`, `num_classes`)
- labels_c: (minibatch, 1 + `num_classes` x `K`)
"""
def __init__(self, encoder_q: Classifier, encoder_k: Classifier, num_classes, K=40, m=0.999, T=0.07):
super(BiTuning, self).__init__()
self.K = K
self.m = m
self.T = T
self.num_classes = num_classes
# create the encoders
# num_classes is the output fc dimension
self.encoder_q = encoder_q
self.encoder_k = encoder_k
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
# create the queue
self.register_buffer("queue_h", torch.randn(encoder_q.features_dim + 1, num_classes, K))
self.register_buffer("queue_z", torch.randn(encoder_q.projection_dim, num_classes, K))
self.queue_h = normalize(self.queue_h, dim=0)
self.queue_z = normalize(self.queue_z, dim=0)
self.register_buffer("queue_ptr", torch.zeros(num_classes, dtype=torch.long))
@torch.no_grad()
def _momentum_update_key_encoder(self):
"""
Momentum update of the key encoder
"""
for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)
@torch.no_grad()
def _dequeue_and_enqueue(self, h, z, label):
batch_size = h.shape[0]
assert self.K % batch_size == 0 # for simplicity
ptr = int(self.queue_ptr[label])
# replace the keys at ptr (dequeue and enqueue)
self.queue_h[:, label, ptr: ptr + batch_size] = h.T
self.queue_z[:, label, ptr: ptr + batch_size] = z.T
# move pointer
self.queue_ptr[label] = (ptr + batch_size) % self.K
def forward(self, im_q, im_k, labels):
    """One Bi-Tuning step on a labeled batch.

    Args:
        im_q: query images, shape (N, *).
        im_k: key images (a second view of the same instances), shape (N, *).
        labels: ground-truth class labels, shape (N,).

    Returns:
        - y_q: classifier predictions of the query encoder, (N, num_classes)
        - logits_z: log-probabilities of contrastive logits in projection space,
          (N, 1 + num_classes * K)
        - logits_y: log-probabilities of classifier-score contrastive logits,
          (N, 1 + num_classes * K)
        - labels_c: soft contrastive targets with uniform mass on the
          1 + K positive slots, (N, 1 + num_classes * K)

    Side effects: momentum-updates the key encoder and pushes this batch's
    key features into the per-class queues.
    """
    batch_size = im_q.size(0)
    device = im_q.device
    # compute query features; the encoder returns (class logits, projection, feature)
    y_q, z_q, h_q = self.encoder_q(im_q)
    # compute key features
    with torch.no_grad():  # no gradient to keys
        self._momentum_update_key_encoder()  # update the key encoder
        y_k, z_k, h_k = self.encoder_k(im_k)
    # ---- contrastive logits in the projection space z ----
    # similarity of each query with its own key: N x 1
    logits_z_cur = torch.einsum('nc,nc->n', [z_q, z_k]).unsqueeze(-1)
    queue_z = self.queue_z.clone().detach().to(device)
    # positives (same-class queue entries): N x K
    logits_z_pos = torch.Tensor([]).to(device)
    # negatives (other-class queue entries): N x ((C-1) x K)
    logits_z_neg = torch.Tensor([]).to(device)
    for i in range(batch_size):
        c = labels[i]
        pos_samples = queue_z[:, c, :]  # D x K
        neg_samples = torch.cat([queue_z[:, 0: c, :], queue_z[:, c + 1:, :]], dim=1).flatten(
            start_dim=1)  # D x ((C-1)xK)
        ith_pos = torch.einsum('nc,ck->nk', [z_q[i: i + 1], pos_samples])  # 1 x K
        ith_neg = torch.einsum('nc,ck->nk', [z_q[i: i + 1], neg_samples])  # 1 x ((C-1)xK)
        logits_z_pos = torch.cat((logits_z_pos, ith_pos), dim=0)
        logits_z_neg = torch.cat((logits_z_neg, ith_neg), dim=0)
        # enqueue this sample's key feature and projection into its class queue
        self._dequeue_and_enqueue(h_k[i:i + 1], z_k[i:i + 1], labels[i])
    logits_z = torch.cat([logits_z_cur, logits_z_pos, logits_z_neg], dim=1)  # Nx(1+C*K)
    # apply temperature, then log-softmax over the 1 + C*K candidates
    logits_z /= self.T
    logits_z = nn.LogSoftmax(dim=1)(logits_z)
    # ---- contrastive logits from the classifier head ----
    # fold the bias into the weight matrix (queue_h carries a matching extra
    # feature slot, see __init__) and L2-normalize each class row
    w = torch.cat([self.encoder_q.head.weight.data, self.encoder_q.head.bias.data.unsqueeze(-1)], dim=1)
    w = normalize(w, dim=1)  # C x F
    # classifier scores of the current batch: N x C
    logits_y_cur = torch.einsum('nk,kc->nc', [h_q, w.T])  # N x C
    queue_y = self.queue_h.clone().detach().to(device).flatten(start_dim=1).T  # (C * K) x F
    # scores of every queued feature for every class, regrouped by queue class
    logits_y_queue = torch.einsum('nk,kc->nc', [queue_y, w.T]).reshape(self.num_classes, -1,
                                                                       self.num_classes)  # C x K x C
    logits_y = torch.Tensor([]).to(device)
    for i in range(batch_size):
        c = labels[i]
        # assemble the ith row: own score, same-class queue scores, then
        # other-class queue scores — all evaluated on class c
        cur_sample = logits_y_cur[i:i + 1, c]  # 1
        pos_samples = logits_y_queue[c, :, c]  # K
        neg_samples = torch.cat([logits_y_queue[0: c, :, c], logits_y_queue[c + 1:, :, c]], dim=0).view(
            -1)  # (C-1)*K
        ith = torch.cat([cur_sample, pos_samples, neg_samples])  # 1+C*K
        logits_y = torch.cat([logits_y, ith.unsqueeze(dim=0)], dim=0)
    logits_y /= self.T
    logits_y = nn.LogSoftmax(dim=1)(logits_y)
    # soft contrastive targets: uniform over the 1 + K positive positions, zero elsewhere
    labels_c = torch.zeros([batch_size, self.K * self.num_classes + 1]).to(device)
    labels_c[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))
    return y_q, logits_z, logits_y, labels_c
================================================
FILE: tllib/regularization/bss.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import torch
import torch.nn as nn
__all__ = ['BatchSpectralShrinkage']
class BatchSpectralShrinkage(nn.Module):
    r"""
    Regularization term from `Catastrophic Forgetting Meets Negative Transfer:
    Batch Spectral Shrinkage for Safe Transfer Learning (NIPS 2019)`_.

    Given a feature matrix :math:`F`, BSS penalizes the squares of its
    :math:`k` smallest singular values:

    .. math::
        L_{bss}(F) = \sum_{i=1}^{k} \sigma_{-i}^2 ,

    where :math:`\sigma_{-i}` is the :math:`i`-th smallest singular value of
    :math:`F`, obtained via SVD :math:`F = U\Sigma V^T`.

    Args:
        k (int): The number of singular values to be penalized. Default: 1

    Shape:
        - Input: :math:`(b, |\mathcal{f}|)` where :math:`b` is the batch size
          and :math:`|\mathcal{f}|` is the feature dimension.
        - Output: scalar.
    """

    def __init__(self, k=1):
        super(BatchSpectralShrinkage, self).__init__()
        self.k = k

    def forward(self, feature):
        # torch.svd returns singular values in descending order, so the k
        # smallest ones are at the tail of `sigma`
        _, sigma, _ = torch.svd(feature.t())
        count = sigma.size(0)
        penalty = 0
        for idx in range(self.k):
            penalty = penalty + torch.pow(sigma[count - 1 - idx], 2)
        return penalty
================================================
FILE: tllib/regularization/co_tuning.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
from typing import Tuple, Optional, List, Dict
import os
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import tqdm
from .lwf import Classifier as ClassifierBase
__all__ = ['Classifier', 'CoTuningLoss', 'Relationship']
class CoTuningLoss(nn.Module):
    """
    The Co-Tuning loss from `Co-Tuning for Transfer Learning (NIPS 2020)`_:
    cross entropy between soft targets p(y_s|y_t) and the source-head
    predictions, averaged over the batch.

    Inputs:
        - input: p(y_s) predicted by the source classifier.
        - target: p(y_s|y_t), where y_t is the ground-truth target label.

    Shape:
        - input: (b, N_p), N_p being the number of source classes.
        - target: (b, N_p).
        - Outputs: scalar.
    """

    def __init__(self):
        super(CoTuningLoss, self).__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # soft-target cross entropy: -sum_j target_j * log softmax(input)_j
        per_sample = -(target * F.log_softmax(input, dim=-1)).sum(dim=-1)
        return per_sample.mean()
class Relationship(object):
    """Learns the category relationship p(y_s|y_t) between a source-pretrained
    classifier and a target dataset.

    Args:
        data_loader (torch.utils.data.DataLoader): A data loader of target dataset.
        classifier (torch.nn.Module): A classifier for Co-Tuning.
        device (torch.nn.Module): The device to run classifier.
        cache (str, optional): Path to find and save the relationship file.
            NOTE(review): ``np.save`` appends ``.npy`` when the path lacks it,
            in which case the cache will never be found on a later run — pass a
            path ending in ``.npy`` to make caching effective.
    """

    def __init__(self, data_loader, classifier, device, cache=None):
        super(Relationship, self).__init__()
        self.data_loader = data_loader
        self.classifier = classifier
        self.device = device
        if cache is not None and os.path.exists(cache):
            # reuse a previously computed relationship matrix
            self.relationship = np.load(cache)
        else:
            source_predictions, target_labels = self.collect_labels()
            self.relationship = self.get_category_relationship(source_predictions, target_labels)
            if cache is not None:
                np.save(cache, self.relationship)

    def __getitem__(self, category):
        return self.relationship[category]

    def collect_labels(self):
        """
        Run the source classifier over the target dataset.

        Returns:
            - source probabilities, ndarray of shape [N, N_p], N_p being the
              number of source classes
            - target labels, ndarray of shape [N], each in [0, N_t)
        """
        print("Collecting labels to calculate relationship")
        probabilities, labels = [], []
        self.classifier.eval()
        with torch.no_grad():
            for x, label in tqdm.tqdm(self.data_loader):
                logits = self.classifier(x.to(self.device))
                probabilities.append(F.softmax(logits, dim=1).detach().cpu().numpy())
                labels.append(label)
        return np.concatenate(probabilities, 0), np.concatenate(labels, 0)

    def get_category_relationship(self, source_probabilities, target_labels):
        """
        Direct estimation of p(y_s | y_t): average the source predictions of
        all samples belonging to each target class.

        Args:
            source_probabilities (numpy.array): [N, N_p]
            target_labels (numpy.array): [N], each in [0, N_t)

        Returns:
            [N_t, N_p] matrix of conditional probabilities p(source class | target class).
        """
        num_target_classes = np.max(target_labels) + 1
        rows = [np.mean(source_probabilities[target_labels == c], axis=0, keepdims=True)
                for c in range(num_target_classes)]
        return np.concatenate(rows)
class Classifier(ClassifierBase):
    """A Classifier used in `Co-Tuning for Transfer Learning (NIPS 2020)`_.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data.
        num_classes (int): Number of classes.
        head_source (torch.nn.Module): Classifier head of source model.
        head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True

    Inputs:
        - x (tensor): input data fed to backbone

    Outputs:
        - y_s: predictions of source classifier head
        - y_t: predictions of target classifier head

    Shape:
        - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions
        - y_s: (b, N), where N is the number of classes
        - y_t: (b, N), where N is the number of classes
    """

    def __init__(self, backbone: nn.Module, num_classes: int, head_source, **kwargs):
        super(Classifier, self).__init__(backbone, num_classes, head_source, **kwargs)

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        # pretrained modules train at a reduced rate when fine-tuning
        pretrained_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": pretrained_lr},
            {"params": self.head_source.parameters(), "lr": pretrained_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head_target.parameters(), "lr": 1.0 * base_lr},
        ]
================================================
FILE: tllib/regularization/delta.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import torch
import torch.nn as nn
import functools
from collections import OrderedDict
class L2Regularization(nn.Module):
    r"""L2 penalty on all parameters :math:`w` of a model:

    .. math::
        {\Omega} (w) = \dfrac{1}{2} \Vert w\Vert_2^2 ,

    Args:
        model (torch.nn.Module): The model to apply L2 penalty.

    Shape:
        - Output: scalar.
    """

    def __init__(self, model: nn.Module):
        super(L2Regularization, self).__init__()
        self.model = model

    def forward(self):
        # sum of half squared norms over every parameter tensor
        return sum((0.5 * torch.norm(w) ** 2 for w in self.model.parameters()), 0.0)
class SPRegularization(nn.Module):
    r"""
    The SP (Starting Point) regularization from `Explicit inductive bias for transfer learning
    with convolutional networks (ICML 2018)`_

    The SP regularization of parameters :math:`w` can be described as:

    .. math::
        {\Omega} (w) = \dfrac{1}{2} \Vert w-w_0\Vert_2^2 ,

    where :math:`w_0` is the parameter vector of the model pretrained on the source problem,
    acting as the starting point (SP) in fine-tuning.

    Args:
        source_model (torch.nn.Module): The source (starting point) model.
        target_model (torch.nn.Module): The target (fine-tuning) model.

    Shape:
        - Output: scalar.
    """

    def __init__(self, source_model: nn.Module, target_model: nn.Module):
        super(SPRegularization, self).__init__()
        self.target_model = target_model
        self.source_weight = {}
        for name, param in source_model.named_parameters():
            # detach().clone() takes a true snapshot. A bare detach() would
            # share storage with the source parameters, so any later in-place
            # update of source_model would silently shift the reference point —
            # and when source_model is target_model (regularizing toward the
            # initial weights), the penalty would be identically zero.
            self.source_weight[name] = param.detach().clone()

    def forward(self):
        output = 0.0
        for name, param in self.target_model.named_parameters():
            output += 0.5 * torch.norm(param - self.source_weight[name]) ** 2
        return output
class BehavioralRegularization(nn.Module):
    r"""
    Behavioral regularization from `DELTA: DEep Learning Transfer using Feature Map
    with Attention for convolutional networks (ICLR 2019)`_:

    .. math::
        {\Omega} (w) = \sum_{j=1}^{N} \Vert FM_j(w, \boldsymbol x)-FM_j(w^0, \boldsymbol x)\Vert_2^2 ,

    where :math:`FM_j(w, \boldsymbol x)` is the feature map of the :math:`j`-th
    selected layer under parameters :math:`w`, and :math:`w^0` are the
    pretrained (starting point) parameters.

    Inputs:
        layer_outputs_source (OrderedDict): layer name -> source-model feature map.
        layer_outputs_target (OrderedDict): layer name -> target-model feature map.

    Shape:
        - Output: scalar.
    """

    def __init__(self):
        super(BehavioralRegularization, self).__init__()

    def forward(self, layer_outputs_source, layer_outputs_target):
        # source feature maps act as fixed references, hence the detach
        return sum(
            (0.5 * (torch.norm(tgt - src.detach()) ** 2)
             for src, tgt in zip(layer_outputs_source.values(), layer_outputs_target.values())),
            0.0,
        )
class AttentionBehavioralRegularization(nn.Module):
    r"""
    Attention-weighted behavioral regularization from `DELTA: DEep Learning Transfer
    using Feature Map with Attention for convolutional networks (ICLR 2019)`_:

    .. math::
        {\Omega} (w) = \sum_{j=1}^{N} W_j(w) \Vert FM_j(w, \boldsymbol x)-FM_j(w^0, \boldsymbol x)\Vert_2^2 ,

    where :math:`W_j(w)` is the channel attention of the :math:`j`-th selected
    layer, :math:`FM_j` its feature map, and :math:`w^0` the pretrained
    (starting point) parameters.

    Args:
        channel_attention (list): Per-layer channel attentions; for a layer with
            C channels, a tensor of shape [C].

    Inputs:
        layer_outputs_source (OrderedDict): layer name -> source-model feature map.
        layer_outputs_target (OrderedDict): layer name -> target-model feature map.

    Shape:
        - Output: scalar.
    """

    def __init__(self, channel_attention):
        super(AttentionBehavioralRegularization, self).__init__()
        self.channel_attention = channel_attention

    def forward(self, layer_outputs_source, layer_outputs_target):
        total = 0.0
        pairs = zip(layer_outputs_source.values(), layer_outputs_target.values())
        for layer_idx, (src, tgt) in enumerate(pairs):
            b, c, h, w = src.shape
            # flatten spatial dims so the norm is taken per (sample, channel)
            diff = tgt.reshape(b, c, h * w) - src.reshape(b, c, h * w).detach()
            channel_dist = torch.norm(diff, 2, 2)  # (b, c)
            weighted = c * torch.mul(self.channel_attention[layer_idx], channel_dist ** 2) / (h * w)
            total += 0.5 * torch.sum(weighted)
        return total
def get_attribute(obj, attr, *args):
    """Fetch a (possibly dotted) attribute path from ``obj``.

    ``get_attribute(x, 'a.b')`` is equivalent to ``x.a.b``; an optional third
    positional argument acts as the per-step default, as in ``getattr``.
    """
    target = obj
    for part in attr.split('.'):
        target = getattr(target, part, *args)
    return target
class IntermediateLayerGetter:
    r"""
    Wraps a model to capture the outputs of selected intermediate layers via
    temporary forward hooks.

    Args:
        model (torch.nn.Module): The model whose intermediate feature maps are collected.
        return_layers (list): Names (dotted paths) of modules whose output to return.
        keep_output (bool): If True, also return the model's final output, else None. Default: True

    Returns:
        - An OrderedDict mapping each name in `return_layers` to its output,
          in the same order as `return_layers`.
        - The model's final output, or None when `keep_output` is False.
    """

    def __init__(self, model, return_layers, keep_output=True):
        self._model = model
        self.return_layers = return_layers
        self.keep_output = keep_output

    def __call__(self, *args, **kwargs):
        captured = OrderedDict()
        handles = []
        for name in self.return_layers:
            layer = get_attribute(self._model, name)

            def hook(module, input, output, name=name):
                # `name` bound as a default so each hook keeps its own key
                captured[name] = output

            try:
                handles.append(layer.register_forward_hook(hook))
            except AttributeError:
                raise AttributeError(f'Module {name} not found')
        # the forward pass always runs so the hooks fire; keep_output only
        # controls whether the final output is returned
        if self.keep_output:
            final_output = self._model(*args, **kwargs)
        else:
            self._model(*args, **kwargs)
            final_output = None
        for handle in handles:
            handle.remove()
        return captured, final_output
================================================
FILE: tllib/regularization/knowledge_distillation.py
================================================
import torch.nn as nn
import torch.nn.functional as F
class KnowledgeDistillationLoss(nn.Module):
    """Knowledge Distillation Loss: KL divergence between the temperature-softened
    distributions of a student and a teacher.

    Args:
        T (double): Temperature. Default: 1.
        reduction (str, optional): Reduction applied by the underlying
            :class:`~torch.nn.KLDivLoss`: ``'none'`` | ``'mean'`` | ``'sum'`` |
            ``'batchmean'``. Default: ``'batchmean'``

    Inputs:
        - y_student (tensor): logits output of the student
        - y_teacher (tensor): logits output of the teacher

    Shape:
        - y_student: (minibatch, `num_classes`)
        - y_teacher: (minibatch, `num_classes`)
    """

    def __init__(self, T=1., reduction='batchmean'):
        super(KnowledgeDistillationLoss, self).__init__()
        self.T = T
        self.kl = nn.KLDivLoss(reduction=reduction)

    def forward(self, y_student, y_teacher):
        """"""
        student_log_probs = F.log_softmax(y_student / self.T, dim=-1)
        teacher_probs = F.softmax(y_teacher / self.T, dim=-1)
        return self.kl(student_log_probs, teacher_probs)
================================================
FILE: tllib/regularization/lwf.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
import tqdm
def collect_pretrain_labels(data_loader, classifier, device):
    """Run ``classifier`` over ``data_loader`` and gather its raw predictions.

    The classifier is switched to eval mode and no gradients are recorded;
    ground-truth labels from the loader are ignored. Returns all predictions
    concatenated along dim 0, moved to CPU.
    """
    predictions = []
    classifier.eval()
    with torch.no_grad():
        for x, _ in tqdm.tqdm(data_loader):
            outputs = classifier(x.to(device))
            predictions.append(outputs.detach().cpu())
    return torch.cat(predictions, dim=0)
class Classifier(nn.Module):
    """A Classifier used in `Learning Without Forgetting (ECCV 2016)`_.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data.
        num_classes (int): Number of classes.
        head_source (torch.nn.Module): Classifier head of source model.
        head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
        bottleneck (torch.nn.Module, optional): Feature transform before the target head.
        bottleneck_dim (int, optional): Output dimension of `bottleneck`; required (> 0) when a bottleneck is given.
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True
        pool_layer (torch.nn.Module, optional): Pooling applied to backbone output;
            defaults to global average pooling + flatten.

    Inputs:
        - x (tensor): input data fed to backbone

    Outputs:
        - y_s: predictions of source classifier head
        - y_t: predictions of target classifier head
        (in eval mode, only y_t is returned)

    Shape:
        - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions
        - y_s: (b, N), where N is the number of classes
        - y_t: (b, N), where N is the number of classes
    """

    def __init__(self, backbone: nn.Module, num_classes: int, head_source,
                 head_target: Optional[nn.Module] = None, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, finetune=True, pool_layer=None):
        super(Classifier, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        # default pooling suits conv backbones producing 4-d feature maps
        if pool_layer is not None:
            self.pool_layer = pool_layer
        else:
            self.pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        if bottleneck is not None:
            assert bottleneck_dim > 0
            self.bottleneck = bottleneck
            self._features_dim = bottleneck_dim
        else:
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        self.head_source = head_source
        if head_target is not None:
            self.head_target = head_target
        else:
            self.head_target = nn.Linear(self._features_dim, num_classes)
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor):
        """"""
        features = self.pool_layer(self.backbone(x))
        y_s = self.head_source(features)
        y_t = self.head_target(self.bottleneck(features))
        # during evaluation only the target-head predictions are of interest
        return (y_s, y_t) if self.training else y_t

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            # the source head is intentionally excluded from optimization
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head_target.parameters(), "lr": 1.0 * base_lr},
        ]
================================================
FILE: tllib/reweight/__init__.py
================================================
================================================
FILE: tllib/reweight/groupdro.py
================================================
"""
Modified from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
class AutomaticUpdateDomainWeightModule(object):
    r"""
    Maintains a weight distribution over domains based on per-domain loss history,
    following `Distributionally Robust Neural Networks for Group Shifts: On the Importance
    of Regularization for Worst-Case Generalization (ICLR 2020)`_.

    With :math:`N` domains, each iteration first computes the unweighted loss of every
    sampled domain and then updates the corresponding weights by

    .. math::
        w_k \leftarrow w_k \cdot \exp(\eta \cdot loss_k), \forall k \in [1, N]

    where :math:`\eta` is a hyper parameter ensuring a smoother change of weight.
    As :math:`w \in R^N` denotes a distribution, it is renormalized to sum to one
    after every update. The final objective is the weighted loss

    .. math::
        objective = \sum_{k=1}^N w_k \cdot loss_k

    Args:
        num_domains (int): The number of source domains.
        eta (float): Hyper parameter eta.
        device (torch.device): The device to run on.
    """

    def __init__(self, num_domains: int, eta: float, device):
        # start from the uniform distribution over domains
        self.domain_weight = torch.ones(num_domains).to(device) / num_domains
        self.eta = eta

    def get_domain_weight(self, sampled_domain_idxes):
        """Get domain weight to calculate final objective.

        Inputs:
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_idxes: :math:`(D, )` where D means the number of sampled domains in current mini-batch
            - Outputs: :math:`(D, )`
        """
        domain_weight = self.domain_weight[sampled_domain_idxes]
        # renormalize over the sampled subset so the returned weights sum to one
        domain_weight = domain_weight / domain_weight.sum()
        return domain_weight

    def update(self, sampled_domain_losses: torch.Tensor, sampled_domain_idxes):
        """Update domain weight using loss of current mini-batch.

        Inputs:
            - sampled_domain_losses (tensor): loss of among sampled domains in current mini-batch
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_losses: :math:`(D, )` where D means the number of sampled domains in current mini-batch
            - sampled_domain_idxes: :math:`(D, )`
        """
        sampled_domain_losses = sampled_domain_losses.detach()
        for loss, idx in zip(sampled_domain_losses, sampled_domain_idxes):
            self.domain_weight[idx] *= (self.eta * loss).exp()
        # Renormalize so the stored weights remain a distribution. Without this,
        # the entries grow multiplicatively without bound and eventually overflow;
        # relative ratios — and thus get_domain_weight's output — are unchanged.
        self.domain_weight /= self.domain_weight.sum()
================================================
FILE: tllib/reweight/iwan.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
from tllib.modules.classifier import Classifier as ClassifierBase
class ImportanceWeightModule(object):
    r"""
    Computes per-instance importance weights from the output of a domain
    discriminator, as introduced in `Importance Weighted Adversarial Nets for
    Partial Domain Adaptation (CVPR 2018)`_.

    Args:
        discriminator (torch.nn.Module): A domain discriminator mapping features
            of shape :math:`(N, F)` to domain predictions of shape :math:`(N, 1)`.
        partial_classes_index (list[int], optional): The index of partial classes.
            This parameter exists only for debugging: in real-world datasets the
            partial classes are unknown. Default: None.

    Examples::
        >>> domain_discriminator = DomainDiscriminator(1024, 1024)
        >>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
        >>> f_s = torch.randn(32, 1024)
        >>> w_s = importance_weight_module.get_importance_weight(f_s)
    """

    def __init__(self, discriminator: nn.Module, partial_classes_index: Optional[List[int]] = None):
        self.discriminator = discriminator
        self.partial_classes_index = partial_classes_index

    def get_importance_weight(self, feature):
        """
        Get importance weights for each instance.

        Args:
            feature (tensor): feature from source domain, in shape :math:`(N, F)`

        Returns:
            instance weight in shape :math:`(N, 1)`
        """
        # instances the discriminator confidently assigns to the source domain
        # get small weights; normalize by the batch mean for scale stability
        raw = 1. - self.discriminator(feature)
        normalized = raw / (raw.mean() + 1e-5)
        return normalized.detach()

    def get_partial_classes_weight(self, weights: torch.Tensor, labels: torch.Tensor):
        """
        Get class weight averaged on the partial classes and non-partial classes respectively.

        Args:
            weights (tensor): instance weight in shape :math:`(N, 1)`
            labels (tensor): ground truth labels in shape :math:`(N, 1)`

        .. warning::
            This function is just for debugging, since in real-world dataset, we have no access to the index of \
            partial classes and this function will throw an error when `partial_classes_index` is None.
        """
        assert self.partial_classes_index is not None
        flat_weights = weights.squeeze()
        partial_mask = torch.Tensor([label in self.partial_classes_index for label in labels]).to(flat_weights.device)
        other_mask = 1. - partial_mask

        def masked_mean(mask):
            # mean weight over the masked instances; zero when the mask is empty
            if mask.sum() > 0:
                return (flat_weights * mask).sum() / mask.sum()
            return torch.tensor(0)

        return masked_mean(partial_mask), masked_mean(other_mask)
class ImageClassifier(ClassifierBase):
    r"""Image classifier with a bottleneck, used for `Importance Weighted
    Adversarial Nets for Partial Domain Adaptation (CVPR 2018)`_.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        # project backbone features into a lower-dimensional bottleneck space
        # (linear -> batch norm -> ReLU) before the classification head
        projection = nn.Linear(backbone.out_features, bottleneck_dim)
        bottleneck = nn.Sequential(projection, nn.BatchNorm1d(bottleneck_dim), nn.ReLU())
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
================================================
FILE: tllib/reweight/pada.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Tuple
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import torch
import torch.nn.functional as F
class AutomaticUpdateClassWeightModule(object):
    r"""
    Maintains class weights estimated from classifier outputs (see
    ``ClassWeightModule`` for the estimation itself) and recomputes them
    automatically every N calls to :meth:`step`.

    Args:
        update_steps (int): N, the number of iterations between two weight updates.
        data_loader (torch.utils.data.DataLoader): Data loader providing classification outputs.
        classifier (torch.nn.Module): Classifier.
        num_classes (int): Number of classes.
        device (torch.device): The device to run classifier.
        temperature (float, optional): T, temperature in ClassWeightModule. Default: 0.1
        partial_classes_index (list[int], optional): The index of partial classes.
            Only for debugging — unknown in real-world datasets. Default: None.

    Examples::
        >>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...)
        >>> num_iterations = 10000
        >>> for _ in range(num_iterations):
        >>>     class_weight_module.step()
        >>>     w_c = class_weight_module.get_class_weight_for_cross_entropy_loss()
        >>>     w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss()
    """

    def __init__(self, update_steps: int, data_loader: DataLoader,
                 classifier: nn.Module, num_classes: int,
                 device: torch.device, temperature: Optional[float] = 0.1,
                 partial_classes_index: Optional[List[int]] = None):
        self.update_steps = update_steps
        self.data_loader = data_loader
        self.classifier = classifier
        self.device = device
        self.class_weight_module = ClassWeightModule(temperature)
        # uniform weights until the first automatic update
        self.class_weight = torch.ones(num_classes).to(device)
        self.num_steps = 0
        self.partial_classes_index = partial_classes_index
        if partial_classes_index is not None:
            self.non_partial_classes_index = [c for c in range(num_classes) if c not in partial_classes_index]

    def step(self):
        self.num_steps += 1
        # recompute the weights over the whole data loader every update_steps calls
        if self.num_steps % self.update_steps == 0:
            outputs = collect_classification_results(self.data_loader, self.classifier, self.device)
            self.class_weight = self.class_weight_module(outputs)

    def get_class_weight_for_cross_entropy_loss(self):
        """
        Outputs: weight for F.cross_entropy

        Shape: :math:`(C, )` where C means the number of classes.
        """
        return self.class_weight

    def get_class_weight_for_adversarial_loss(self, source_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Outputs:
            - w_s: source weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss`
            - w_t: target weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss`

        Shape:
            - w_s: :math:`(minibatch, )`
            - w_t: :math:`(minibatch, )`
        """
        w_s = self.class_weight[source_labels]
        # target instances are unlabeled, so each one gets the mean source weight
        w_t = torch.ones_like(w_s) * w_s.mean()
        return w_s, w_t

    def get_partial_classes_weight(self):
        """
        Get class weight averaged on the partial classes and non-partial classes respectively.

        .. warning::
            This function is just for debugging, since in real-world dataset, we have no access to the index of \
            partial classes and this function will throw an error when `partial_classes_index` is None.
        """
        assert self.partial_classes_index is not None
        partial_mean = torch.mean(self.class_weight[self.partial_classes_index])
        non_partial_mean = torch.mean(self.class_weight[self.non_partial_classes_index])
        return partial_mean, non_partial_mean
class ClassWeightModule(nn.Module):
    r"""
    Calculating class weight based on the output of classifier.
    Introduced by `Partial Adversarial Domain Adaptation (ECCV 2018)`_

    Given classification logits outputs :math:`\{\hat{y}_i\}_{i=1}^n`, where :math:`n` is the dataset size,
    the weight indicating the contribution of each class to the training can be calculated as
    follows

    .. math::
        \mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T),

    where :math:`\mathcal{\gamma}` is a :math:`|\mathcal{C}|`-dimensional weight vector quantifying the contribution
    of each class and T is a hyper-parameters called temperature.

    In practice, it's possible that some of the weights are very small, thus, we normalize weight :math:`\mathcal{\gamma}`
    by dividing its largest element, i.e. :math:`\mathcal{\gamma} \leftarrow \mathcal{\gamma} / max(\mathcal{\gamma})`

    Args:
        temperature (float, optional): hyper-parameters :math:`T`. Default: 0.1

    Shape:
        - Inputs: (minibatch, :math:`|\mathcal{C}|`)
        - Outputs: (:math:`|\mathcal{C}|`,)
    """

    def __init__(self, temperature: Optional[float] = 0.1):
        super(ClassWeightModule, self).__init__()
        self.temperature = temperature

    def forward(self, outputs: torch.Tensor):
        # Use a detached *copy*. The previous in-place `outputs.detach_()`
        # mutated the caller's tensor, silently stripping its autograd history
        # as a side effect visible outside this module.
        outputs = outputs.detach()
        softmax_outputs = F.softmax(outputs / self.temperature, dim=1)
        class_weight = torch.mean(softmax_outputs, dim=0)
        # rescale so the largest class weight is exactly 1
        class_weight = class_weight / torch.max(class_weight)
        class_weight = class_weight.view(-1)
        return class_weight
def collect_classification_results(data_loader: DataLoader, classifier: nn.Module,
                                   device: torch.device) -> torch.Tensor:
    r"""
    Run ``classifier`` over every batch of ``data_loader`` and concatenate the outputs.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader.
        classifier (torch.nn.Module): A classifier.
        device (torch.device)

    Returns:
        Classification results in shape (len(data_loader), :math:`|\mathcal{C}|`).
    """
    was_training = classifier.training
    classifier.eval()
    collected = []
    with torch.no_grad():
        for images, _ in data_loader:
            collected.append(classifier(images.to(device)))
    # restore the caller's train/eval mode
    classifier.train(was_training)
    return torch.cat(collected, dim=0)
================================================
FILE: tllib/self_training/__init__.py
================================================
================================================
FILE: tllib/self_training/cc_loss.py
================================================
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
from ..modules.entropy import entropy
__all__ = ['CCConsistency']
class CCConsistency(nn.Module):
    r"""
    CC loss: class-confusion consistency regularization on top of MCC.

    An entropy-weighted class-confusion matrix is computed for both the weak and
    the strongly-augmented predictions, restricted to confident samples, and the
    squared difference between the two matrices is penalized.

    Args:
        temperature (float): The temperature for rescaling; the prediction shrinks
            to vanilla softmax if temperature is 1.0. Must be larger than 0.
        thr (float): The confidence threshold, in (0, 1). Default: 0.7
            (0.7 for DomainNet, 0.95 for other datasets).

    Inputs:
        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`
        - g_t_strong (tensor): unnormalized classifier predictions on target domain,
          with strong data augmentation, :math:`g^t_{strong}`

    Shape:
        - g_t, g_t_strong: :math:`(minibatch, C)` where C means the number of classes.
        - Output: a 2-tuple ``(consistency_loss, selected_ratio)``, both scalars.
          (The original ``-> torch.Tensor`` annotation was wrong and has been removed;
          when no sample passes the threshold the tuple ``(0, 0)`` is returned.)

    Examples::
        >>> temperature = 2.0
        >>> loss = CCConsistency(temperature)
        >>> g_t = torch.randn(batch_size, num_classes)
        >>> g_t_strong = torch.randn(batch_size, num_classes)
        >>> output, ratio = loss(g_t, g_t_strong)
    """
    def __init__(self, temperature: float, thr=0.7):
        super(CCConsistency, self).__init__()
        self.temperature = temperature
        self.thr = thr

    def _class_confusion(self, logits: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Entropy-weighted, normalized class-confusion matrix (C x C).

        `batch_size` is the ORIGINAL batch size, used for weight scaling even
        when `logits` has been filtered down to confident samples (this matches
        the reference implementation).
        """
        predictions = F.softmax(logits / self.temperature, dim=1)
        # low-entropy (certain) samples receive weights approaching 2
        entropy_weight = entropy(predictions).detach()
        entropy_weight = 1 + torch.exp(-entropy_weight)
        entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1)
        matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions)
        return matrix / torch.sum(matrix, dim=1)

    def forward(self, logits: torch.Tensor, logits_strong: torch.Tensor):
        batch_size, num_classes = logits.shape
        logits = logits.detach()
        prediction_thr = F.softmax(logits / self.temperature, dim=1)
        max_probs, _ = torch.max(prediction_thr, dim=-1)
        mask_binary = max_probs.ge(self.thr)
        mask = mask_binary.float().detach()
        if mask.sum() == 0:
            # no confident sample in this batch: zero loss, zero selected ratio
            return 0, 0
        logits = logits[mask_binary]
        logits_strong = logits_strong[mask_binary]
        confusion_weak = self._class_confusion(logits, batch_size)
        # typo fixed: was `predictions_stong` in the original
        confusion_strong = self._class_confusion(logits_strong, batch_size)
        consistency_loss = ((confusion_weak - confusion_strong) ** 2).sum() / num_classes * mask.sum() / batch_size
        return consistency_loss, mask.sum() / batch_size
================================================
FILE: tllib/self_training/dst.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.modules.classifier import Classifier
class ImageClassifier(Classifier):
    r"""
    Classifier with a non-linear pseudo head :math:`h_{\text{pseudo}}` and a worst-case estimation head
    :math:`h_{\text{worst}}` from `Debiased Self-Training for Semi-Supervised Learning`.
    Both extra heads are connected directly to the feature extractor :math:`\psi`. The end-to-end
    adversarial training procedure between :math:`\psi` and :math:`h_{\text{worst}}` is implemented by
    routing the features through a warm-started gradient reverse layer before the worst-case head.
    Both extra heads can be safely discarded during inference (``forward`` returns only the main head's
    predictions in eval mode), and thus introduce no inference cost.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): Hidden dimension of the non-linear pseudo head and worst-case
            estimation head. Default: 2048

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        - outputs: predictions of the main head :math:`h`
        - outputs_adv: predictions of the worst-case estimation head :math:`h_{\text{worst}}` (training mode only)
        - outputs_pseudo: predictions of the pseudo head :math:`h_{\text{pseudo}}` (training mode only)

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - outputs, outputs_adv, outputs_pseudo: (minibatch, `num_classes`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, width=2048, **kwargs):
        # bottleneck projects backbone features down to `bottleneck_dim`
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # small-variance init of the bottleneck projection (matches the reference implementation)
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        # non-linear pseudo head h_pseudo: features_dim -> width -> num_classes
        self.pseudo_head = nn.Sequential(
            nn.Linear(self.features_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, self.num_classes)
        )
        # gradient reverse layer: its coefficient warms up from `lo` to `hi` over `max_iters`
        # calls to step(); auto_step=False, so the training loop must call self.step() explicitly
        self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        # worst-case estimation head h_worst, same architecture as the pseudo head
        self.adv_head = nn.Sequential(
            nn.Linear(self.features_dim, width),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(width, self.num_classes)
        )

    def forward(self, x: torch.Tensor):
        # pool_layer / bottleneck / head / features_dim are provided by the base Classifier
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        # gradients flowing back from the adversarial head are reversed here
        f_adv = self.grl_layer(f)
        outputs_adv = self.adv_head(f_adv)
        outputs = self.head(f)
        outputs_pseudo = self.pseudo_head(f)
        if self.training:
            return outputs, outputs_adv, outputs_pseudo
        else:
            # inference: the auxiliary heads are ignored
            return outputs

    def get_parameters(self, base_lr=1.0):
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        params = [
            # a pretrained backbone is fine-tuned with a 10x smaller learning rate
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.pseudo_head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.adv_head.parameters(), "lr": 1.0 * base_lr}
        ]
        return params

    def step(self):
        """Advance the warm-start schedule of the gradient reverse layer by one iteration."""
        self.grl_layer.step()
def shift_log(x, offset=1e-6):
    """Numerically-stable logarithm: add a small offset before taking the log.

    The shifted value is capped at 1, so the result never exceeds 0.
    """
    shifted = torch.clamp(x + offset, max=1.)
    return torch.log(shifted)
class WorstCaseEstimationLoss(nn.Module):
    r"""
    Worst-case estimation loss from `Debiased Self-Training for Semi-Supervised Learning`.

    The worst possible head :math:`h_{\text{worst}}` is driven to predict correctly on all
    labeled samples :math:`\mathcal{L}` while making as many mistakes as possible on the
    unlabeled data :math:`\mathcal{U}`. For classification:

    .. math::
        loss(\mathcal{L}, \mathcal{U}) =
        \eta' \mathbb{E}_{y^l, y_{adv}^l \sim\hat{\mathcal{L}}} -\log\left(\frac{\exp(y_{adv}^l[h_{y^l}])}{\sum_j \exp(y_{adv}^l[j])}\right) +
        \mathbb{E}_{y^u, y_{adv}^u \sim\hat{\mathcal{U}}} -\log\left(1-\frac{\exp(y_{adv}^u[h_{y^u}])}{\sum_j \exp(y_{adv}^u[j])}\right),

    where :math:`y^l, y^u` are logits of the main head :math:`h` on labeled/unlabeled data and
    :math:`y_{adv}^l, y_{adv}^u` are logits of the worst-case head :math:`h_{\text{worst}}`;
    :math:`h_y` is the predicted label for logits :math:`y`.

    Args:
        eta_prime (float): the trade-off hyper parameter :math:`\eta'`.

    Inputs:
        - y_l: logits output :math:`y^l` by the main head on labeled data
        - y_l_adv: logits output :math:`y^l_{adv}` by the worst-case estimation head on labeled data
        - y_u: logits output :math:`y^u` by the main head on unlabeled data
        - y_u_adv: logits output :math:`y^u_{adv}` by the worst-case estimation head on unlabeled data

    Shape:
        - Inputs: :math:`(minibatch, C)` where C denotes the number of classes.
        - Output: scalar.
    """
    def __init__(self, eta_prime):
        super(WorstCaseEstimationLoss, self).__init__()
        self.eta_prime = eta_prime

    def forward(self, y_l, y_l_adv, y_u, y_u_adv):
        # labeled part: the adversarial head must agree with the main head's prediction
        labels_l = y_l.argmax(dim=1)
        labeled_term = self.eta_prime * F.cross_entropy(y_l_adv, labels_l)
        # unlabeled part: push the adversarial head away from the main head's prediction;
        # shift_log keeps log(1 - p) finite when p -> 1
        labels_u = y_u.argmax(dim=1)
        unlabeled_term = F.nll_loss(shift_log(1. - F.softmax(y_u_adv, dim=1)), labels_u)
        return labeled_term + unlabeled_term
================================================
FILE: tllib/self_training/flexmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from collections import Counter
import torch
class DynamicThresholdingModule(object):
    r"""
    Dynamic thresholding module from `FlexMatch: Boosting Semi-Supervised Learning with
    Curriculum Pseudo Labeling`.

    At time :math:`t`, the learning status :math:`\sigma_t(c)` of category :math:`c` is
    estimated by the number of samples whose predictions fall into this class and pass a
    fixed confidence threshold (e.g. 0.95), then normalized to [0, 1]:

    .. math::
        \beta_t(c) = \frac{\sigma_t(c)}{\underset{c'}{\text{max}}~\sigma_t(c')}.

    The dynamic threshold is

    .. math::
        \mathcal{T}_t(c) = \mathcal{M}(\beta_t(c)) \cdot \tau,

    where :math:`\tau` is the pre-defined threshold (e.g. 0.95) and :math:`\mathcal{M}` is a
    (possibly non-linear) increasing mapping function.

    Args:
        threshold (float): The pre-defined confidence threshold
        warmup (bool): Whether perform threshold warm-up. If True, the number of unlabeled data
            that have not been used will be considered when normalizing :math:`\sigma_t(c)`
        mapping_func (callable): An increasing mapping function, e.g. concave
            :math:`\mathcal{M}(x)=\text{ln}(x+1)/\text{ln}2`, linear :math:`\mathcal{M}(x)=x`,
            or convex :math:`\mathcal{M}(x)=2/2-x`
        num_classes (int): Number of classes
        n_unlabeled_samples (int): Size of the unlabeled dataset
        device (torch.device): Device
    """
    def __init__(self, threshold, warmup, mapping_func, num_classes, n_unlabeled_samples, device):
        self.threshold = threshold
        self.warmup = warmup
        self.mapping_func = mapping_func
        self.num_classes = num_classes
        self.n_unlabeled_samples = n_unlabeled_samples
        self.device = device
        # -1 marks samples whose prediction has not (yet) passed the threshold
        self.net_outputs = torch.full((n_unlabeled_samples,), -1, dtype=torch.long, device=device)

    def get_threshold(self, pseudo_labels):
        """Calculate and return the dynamic threshold for each pseudo label."""
        counts = Counter(self.net_outputs.tolist())
        if max(counts.values()) == self.n_unlabeled_samples:
            # Early stage of training: the network emits no confident pseudo label yet,
            # so every category's learning status is simply zero.
            status = torch.zeros(self.num_classes).to(self.device)
        else:
            if not self.warmup:
                # drop the "unused" bucket so normalization only sees real classes
                counts.pop(-1, None)
            denominator = max(counts.values())
            # estimate per-class learning status, normalized by the best-learned class
            status = torch.tensor([counts[c] / denominator for c in range(self.num_classes)],
                                  dtype=torch.float, device=self.device)
        # map status through M and rescale by the fixed threshold
        return self.threshold * self.mapping_func(status[pseudo_labels])

    def update(self, idxes, selected_mask, pseudo_labels):
        """Update the learning status.

        Args:
            idxes (tensor): Indexes of corresponding samples
            selected_mask (tensor): A binary mask, a value of 1 indicates the prediction
                for this sample will be updated
            pseudo_labels (tensor): Network predictions
        """
        chosen = selected_mask == 1
        if idxes[chosen].nelement() != 0:
            self.net_outputs[idxes[chosen]] = pseudo_labels[chosen]
================================================
FILE: tllib/self_training/mcc.py
================================================
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
from ..modules.entropy import entropy
__all__ = ['MinimumClassConfusionLoss', 'ImageClassifier']
class MinimumClassConfusionLoss(nn.Module):
    r"""
    Minimum Class Confusion loss: minimizes the between-class confusion in the target
    predictions. See `Minimum Class Confusion for Versatile Domain Adaptation (ECCV 2020)`
    for details.

    Args:
        temperature (float): The temperature for rescaling; the prediction shrinks to
            vanilla softmax if temperature is 1.0. Must be larger than 0.

    Inputs: g_t
        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`

    Shape:
        - g_t: :math:`(minibatch, C)` where C means the number of classes.
        - Output: scalar.

    Examples::
        >>> temperature = 2.0
        >>> loss = MinimumClassConfusionLoss(temperature)
        >>> # logits output from target domain
        >>> g_t = torch.randn(batch_size, num_classes)
        >>> output = loss(g_t)

    MCC can also serve as a regularizer for existing methods.

    Examples::
        >>> from tllib.modules.domain_discriminator import DomainDiscriminator
        >>> num_classes = 2
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> temperature = 2.0
        >>> discriminator = DomainDiscriminator(in_feature=feature_dim, hidden_size=1024)
        >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
        >>> mcc_loss = MinimumClassConfusionLoss(temperature)
        >>> # features from source domain and target domain
        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # logits output from source domain adn target domain
        >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t)
    """
    def __init__(self, temperature: float):
        super(MinimumClassConfusionLoss, self).__init__()
        self.temperature = temperature

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        batch_size, num_classes = logits.shape
        probs = F.softmax(logits / self.temperature, dim=1)  # batch_size x num_classes
        # certainty-based reweighting: low-entropy (confident) samples get weight -> 2
        weight = 1 + torch.exp(-entropy(probs).detach())
        weight = (batch_size * weight / weight.sum()).unsqueeze(dim=1)  # batch_size x 1
        # weighted class-confusion matrix, normalized per class
        confusion = (probs * weight).t().mm(probs)  # num_classes x num_classes
        confusion = confusion / confusion.sum(dim=1)
        # off-diagonal mass = cross-class confusion to be minimized
        return (confusion.sum() - confusion.trace()) / num_classes
class ImageClassifier(ClassifierBase):
    """Image classifier for MCC: a backbone followed by a Linear-BatchNorm-ReLU
    bottleneck; classification head and pooling come from `ClassifierBase`."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        neck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, neck, bottleneck_dim, **kwargs)
================================================
FILE: tllib/self_training/mean_teacher.py
================================================
import copy
from typing import Optional
import torch
def set_requires_grad(net, requires_grad=False):
    """
    Enable or disable gradient computation for every parameter of `net`.
    Disabling (the default) avoids unnecessary computations, e.g. for a frozen teacher.
    """
    for p in net.parameters():
        p.requires_grad = requires_grad
class EMATeacher(object):
    r"""
    Exponential-moving-average teacher from `Mean teachers are better role models:
    Weight-averaged consistency targets improve semi-supervised deep learning results (NIPS 2017)`.

    With :math:`\theta_t'` the teacher parameters and :math:`\theta_t` the student parameters
    at training step :math:`t`, and decay factor :math:`\alpha`, the teacher is updated as

    .. math::
        \theta_t'=\alpha \theta_{t-1}' + (1-\alpha)\theta_t

    Args:
        model (torch.nn.Module): the student model
        alpha (float): decay factor for EMA.

    Inputs:
        x (tensor): input tensor

    Examples::
        >>> classifier = ImageClassifier(backbone, num_classes=31, bottleneck_dim=256).to(device)
        >>> # initialize teacher model
        >>> teacher = EMATeacher(classifier, 0.9)
        >>> num_iterations = 1000
        >>> for _ in range(num_iterations):
        >>>     # x denotes input of one mini-batch
        >>>     # you can get teacher model's output by teacher(x)
        >>>     y_teacher = teacher(x)
        >>>     # when you want to update teacher, you should call teacher.update()
        >>>     teacher.update()
    """
    def __init__(self, model, alpha):
        self.model = model
        self.alpha = alpha
        # the teacher starts as a deep copy of the student and is never trained by backprop
        self.teacher = copy.deepcopy(model)
        set_requires_grad(self.teacher, False)

    def set_alpha(self, alpha: float):
        assert alpha >= 0
        self.alpha = alpha

    def update(self):
        """Blend student parameters into the teacher: t <- alpha*t + (1-alpha)*s."""
        for t_param, s_param in zip(self.teacher.parameters(), self.model.parameters()):
            t_param.data = self.alpha * t_param + (1 - self.alpha) * s_param

    def __call__(self, x: torch.Tensor):
        # forward through the teacher, not the student
        return self.teacher(x)

    def train(self, mode: Optional[bool] = True):
        self.teacher.train(mode)

    def eval(self):
        self.train(False)

    def state_dict(self):
        return self.teacher.state_dict()

    def load_state_dict(self, state_dict):
        self.teacher.load_state_dict(state_dict)

    @property
    def module(self):
        # NOTE(review): assumes the wrapped model exposes `.module`
        # (e.g. a DataParallel-style wrapper) — confirm at the call site
        return self.teacher.module
def update_bn(model, ema_model):
    """
    Copy batch-normalization statistics (running mean, running variance and batch
    counter) from the student `model` into the teacher `ema_model`.

    Modules are paired by traversal order; a pair is copied only when both module
    names contain 'bn'.
    """
    for (ema_name, ema_module), (name, module) in zip(ema_model.named_modules(), model.named_modules()):
        if ('bn' in ema_name) and ('bn' in name):
            ema_state, student_state = ema_module.state_dict(), module.state_dict()
            for key in ('running_mean', 'running_var', 'num_batches_tracked'):
                ema_state[key].data.copy_(student_state[key].data)
================================================
FILE: tllib/self_training/pi_model.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Callable, Optional
import numpy as np
import torch
from torch import nn as nn
def sigmoid_warm_up(current_epoch, warm_up_epochs: int):
    """Exponential warm-up function from `Temporal Ensembling for Semi-Supervised Learning
    (ICLR 2017)`.

    Ramps from exp(-5) at epoch 0 up to 1.0 at `warm_up_epochs`; returns 1.0
    immediately when no warm-up is requested.
    """
    assert warm_up_epochs >= 0
    if warm_up_epochs == 0:
        return 1.0
    progress = np.clip(current_epoch, 0.0, warm_up_epochs) / warm_up_epochs
    remaining = 1.0 - progress
    return float(np.exp(-5.0 * remaining * remaining))
class ConsistencyLoss(nn.Module):
    r"""
    Consistency loss between two predictions. Given a distance measure :math:`D`,
    predictions :math:`p_1, p_2` and a binary mask :math:`mask`, the loss is

    .. math::
        D(p_1, p_2) * mask

    Args:
        distance_measure (callable): Distance measure function.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs:
        - p1: the first prediction
        - p2: the second prediction
        - mask: binary mask. Default: 1. (use all samples when calculating loss)

    Shape:
        - p1, p2: :math:`(N, C)` where C means the number of classes.
        - mask: :math:`(N, )` where N means mini-batch size.
    """
    def __init__(self, distance_measure: Callable, reduction: Optional[str] = 'mean'):
        super(ConsistencyLoss, self).__init__()
        self.distance_measure = distance_measure
        self.reduction = reduction

    def forward(self, p1: torch.Tensor, p2: torch.Tensor, mask=1.):
        masked_distance = self.distance_measure(p1, p2) * mask
        if self.reduction == 'sum':
            return masked_distance.sum()
        if self.reduction == 'mean':
            return masked_distance.mean()
        return masked_distance
class L2ConsistencyLoss(ConsistencyLoss):
    r"""
    Consistency loss with a squared-L2 distance. For predictions :math:`p_1, p_2`
    and binary mask :math:`mask`:

    .. math::
        \text{MSELoss}(p_1, p_2) * mask
    """
    def __init__(self, reduction: Optional[str] = 'mean'):
        def squared_l2(a: torch.Tensor, b: torch.Tensor):
            # per-sample squared Euclidean distance
            return ((a - b) ** 2).sum(dim=1)
        super(L2ConsistencyLoss, self).__init__(squared_l2, reduction)
================================================
FILE: tllib/self_training/pseudo_label.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch.nn as nn
import torch.nn.functional as F
class ConfidenceBasedSelfTrainingLoss(nn.Module):
    """
    Self-training loss that keeps only pseudo labels whose confidence exceeds a
    threshold, following `Pseudo-Label: The Simple and Efficient Semi-Supervised
    Learning Method for Deep Neural Networks (ICML 2013)`.

    Args:
        threshold (float): Confidence threshold.

    Inputs:
        - y: unnormalized classifier predictions.
        - y_target: unnormalized classifier predictions which will used for generating pseudo labels.

    Returns:
        A tuple, including
            - self_training_loss: self training loss with pseudo labels.
            - mask: binary mask that indicates which samples are retained (whose confidence is above the threshold).
            - pseudo_labels: generated pseudo labels.

    Shape:
        - y, y_target: :math:`(minibatch, C)` where C means the number of classes.
        - self_training_loss: scalar.
        - mask, pseudo_labels: :math:`(minibatch, )`.
    """
    def __init__(self, threshold: float):
        super(ConfidenceBasedSelfTrainingLoss, self).__init__()
        self.threshold = threshold

    def forward(self, y, y_target):
        # pseudo labels come from the (detached) target predictions
        probs = F.softmax(y_target.detach(), dim=1)
        confidence, pseudo_labels = probs.max(dim=1)
        # retain only samples whose confidence clears the threshold
        mask = (confidence > self.threshold).float()
        per_sample_loss = F.cross_entropy(y, pseudo_labels, reduction='none')
        self_training_loss = (per_sample_loss * mask).mean()
        return self_training_loss, mask, pseudo_labels
================================================
FILE: tllib/self_training/self_ensemble.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from tllib.modules.classifier import Classifier as ClassifierBase
class ClassBalanceLoss(nn.Module):
    r"""
    Class balance loss that penalises predictions exhibiting large class imbalance.

    Given predictions :math:`p` with dimension :math:`(N, C)`, the mini-batch mean
    per-class probability :math:`p_{mean}` (dimension :math:`(C, )`) is

    .. math::
        p_{mean}^j = \frac{1}{N} \sum_{i=1}^N p_i^j

    and the loss is the binary cross entropy against the uniform vector :math:`u`
    with :math:`u^j = \frac{1}{C}`:

    .. math::
        loss = \text{BCELoss}(p_{mean}, u)

    Args:
        num_classes (int): Number of classes

    Inputs:
        - p (tensor): predictions from classifier

    Shape:
        - p: :math:`(N, C)` where C means the number of classes.
    """
    def __init__(self, num_classes):
        super(ClassBalanceLoss, self).__init__()
        # target distribution: uniform mass 1/C per class
        self.uniform_distribution = torch.ones(num_classes) / num_classes

    def forward(self, p: torch.Tensor):
        batch_mean = p.mean(dim=0)
        target = self.uniform_distribution.to(p.device)
        return F.binary_cross_entropy(batch_mean, target)
class ImageClassifier(ClassifierBase):
    """Classifier for self-ensembling: a backbone plus a Linear->BatchNorm1d->ReLU
    bottleneck; head and pooling are inherited from `ClassifierBase`."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        layers = [
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        ]
        super(ImageClassifier, self).__init__(backbone, num_classes, nn.Sequential(*layers), bottleneck_dim,
                                              **kwargs)
================================================
FILE: tllib/self_training/self_tuning.py
================================================
"""
Adapted from https://github.com/thuml/Self-Tuning/tree/master
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn
from torch.nn.functional import normalize
from tllib.modules.classifier import Classifier as ClassifierBase
class Classifier(ClassifierBase):
    """Classifier class for Self-Tuning.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data
        num_classes (int): Number of classes.
        projection_dim (int, optional): Dimension of the projector head. Default: 1024
            (the previous docstring said 128, which did not match the signature)
        bottleneck_dim (int, optional): Dimension of the bottleneck layer. Default: 1024
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        In the training mode,
            - h: L2-normalized projections
            - y: classifier's predictions
        In the eval mode,
            - y: classifier's predictions

    Shape:
        - Inputs: (minibatch, *) where * means, any number of additional dimensions
        - y: (minibatch, `num_classes`)
        - h: (minibatch, `projection_dim`)
    """
    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=1024, bottleneck_dim=1024, finetune=True,
                 pool_layer=None):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # small-variance init of the bottleneck projection
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        # fix: the head and projector consume `bottleneck_dim`-dimensional features;
        # the previous hard-coded 1024 silently broke any bottleneck_dim != 1024
        # (behavior is unchanged for the default bottleneck_dim=1024)
        head = nn.Linear(bottleneck_dim, num_classes)
        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune,
                                         pool_layer=pool_layer, bottleneck=bottleneck, bottleneck_dim=bottleneck_dim)
        self.projector = nn.Linear(bottleneck_dim, projection_dim)
        self.projection_dim = projection_dim

    def forward(self, x: torch.Tensor):
        """Return (projections, predictions) in training mode, predictions only in eval mode."""
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        # L2-normalized projections for the contrastive branch
        h = self.projector(f)
        h = normalize(h, dim=1)
        # classification predictions
        predictions = self.head(f)
        if self.training:
            return h, predictions
        else:
            return predictions

    def get_parameters(self, base_lr=1.0):
        """Parameter groups with per-module relative learning rates
        (pretrained modules are fine-tuned with a 10x smaller rate)."""
        params = [
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.projector.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
        ]
        return params
class SelfTuning(nn.Module):
    r"""Self-Tuning module in `Self-Tuning for Data-Efficient Deep Learning (self-tuning, ICML 2021)`.

    Maintains a per-class queue of key projections and computes pseudo-group-contrast
    (PGC) logits: for each query, the queued keys of its own class are positives and
    the queued keys of every other class are negatives.

    Args:
        encoder_q (Classifier): Query encoder.
        encoder_k (Classifier): Key encoder (momentum copy of the query encoder).
        num_classes (int): Number of classes
        K (int): Queue size per class. Default: 32
        m (float): Momentum coefficient for the key-encoder update. Default: 0.999
        T (float): Temperature. Default: 0.07

    Inputs:
        - im_q (tensor): input data fed to `encoder_q`
        - im_k (tensor): input data fed to `encoder_k`
        - labels (tensor): classification labels of input data

    Outputs: pgc_logits, pgc_labels, y_q
        - pgc_logits: projector's log-softmax scores on both positive and negative samples
        - pgc_labels: contrastive soft targets (uniform mass over the 1 + K positive slots)
        - y_q: query classifier's predictions

    Shape:
        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions
        - labels: (minibatch, )
        - y_q: (minibatch, `num_classes`)
        - pgc_logits: (minibatch, 1 + `num_classes` :math:`\times` `K`)
        - pgc_labels: (minibatch, 1 + `num_classes` :math:`\times` `K`)
    """
    def __init__(self, encoder_q, encoder_k, num_classes, K=32, m=0.999, T=0.07):
        super(SelfTuning, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        self.num_classes = num_classes
        # create the encoders; the key encoder starts as an exact copy of the
        # query encoder and is updated only by momentum (never by gradients)
        # num_classes is the output fc dimension
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        # create the queue: keys are stored column-wise, one contiguous block of
        # K slots per class, so class c owns columns [c*K, (c+1)*K)
        self.register_buffer("queue_list", torch.randn(encoder_q.projection_dim, K * self.num_classes))
        self.queue_list = normalize(self.queue_list, dim=0)
        # one write pointer per class, each cycling over its own K slots
        self.register_buffer("queue_ptr", torch.zeros(self.num_classes, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, h, label):
        # write the new key(s) into class `label`'s block and advance its pointer.
        # NOTE(review): no wrap-around handling — assumes ptr + batch_size <= K.
        # This holds in `forward`, which enqueues a single key per call; confirm
        # before enqueueing larger batches.
        batch_size = h.shape[0]
        ptr = int(self.queue_ptr[label])
        real_ptr = ptr + label * self.K
        # replace the keys at ptr (dequeue and enqueue)
        self.queue_list[:, real_ptr:real_ptr + batch_size] = h.T
        # move pointer
        ptr = (ptr + batch_size) % self.K
        self.queue_ptr[label] = ptr

    def forward(self, im_q, im_k, labels):
        batch_size = im_q.size(0)
        device = im_q.device
        # compute query features
        h_q, y_q = self.encoder_q(im_q)  # queries: h_q (N x projection_dim)
        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            h_k, _ = self.encoder_k(im_k)  # keys: h_k (N x projection_dim)
        # positive logits (each query against its own key): Nx1
        logits_pos = torch.einsum('nl,nl->n', [h_q, h_k]).unsqueeze(-1)  # Einstein sum is more intuitive
        # snapshot of the queue (projection_dim x K*num_classes) so the
        # positives/negatives stay fixed while new keys are enqueued in the loop below
        cur_queue_list = self.queue_list.clone().detach()
        logits_neg_list = torch.Tensor([]).to(device)
        logits_pos_list = torch.Tensor([]).to(device)
        for i in range(batch_size):
            # negatives: every queue column that does NOT belong to labels[i]'s block
            neg_sample = torch.cat([cur_queue_list[:, 0:labels[i] * self.K],
                                    cur_queue_list[:, (labels[i] + 1) * self.K:]],
                                   dim=1)
            # positives: the K queued keys of labels[i]'s own class
            pos_sample = cur_queue_list[:, labels[i] * self.K: (labels[i] + 1) * self.K]
            ith_neg = torch.einsum('nl,lk->nk', [h_q[i:i + 1], neg_sample])
            ith_pos = torch.einsum('nl,lk->nk', [h_q[i:i + 1], pos_sample])
            logits_neg_list = torch.cat((logits_neg_list, ith_neg), dim=0)
            logits_pos_list = torch.cat((logits_pos_list, ith_pos), dim=0)
            self._dequeue_and_enqueue(h_k[i:i + 1], labels[i])
        # logits: 1 + queue_size + queue_size * (class_num - 1)
        pgc_logits = torch.cat([logits_pos, logits_pos_list, logits_neg_list], dim=1)
        pgc_logits = nn.LogSoftmax(dim=1)(pgc_logits / self.T)
        # targets: uniform probability over the 1 + K positive entries, zero elsewhere
        pgc_labels = torch.zeros([batch_size, 1 + self.K * self.num_classes]).to(device)
        pgc_labels[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))
        return pgc_logits, pgc_labels, y_q
================================================
FILE: tllib/self_training/uda.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch.nn as nn
import torch.nn.functional as F
class StrongWeakConsistencyLoss(nn.Module):
    """
    Consistency loss between strongly- and weakly-augmented samples from
    `Unsupervised Data Augmentation for Consistency Training (NIPS 2020)`.

    The KL divergence between the (detached) weak prediction and the
    temperature-sharpened strong prediction is averaged over samples whose weak
    confidence exceeds the threshold.

    Args:
        threshold (float): Confidence threshold.
        temperature (float): Temperature.

    Inputs:
        - y_strong: unnormalized classifier predictions on strong augmented samples.
        - y: unnormalized classifier predictions on weak augmented samples.

    Shape:
        - y, y_strong: :math:`(minibatch, C)` where C means the number of classes.
        - Output: scalar.
    """
    def __init__(self, threshold: float, temperature: float):
        super(StrongWeakConsistencyLoss, self).__init__()
        self.threshold = threshold
        self.temperature = temperature

    def forward(self, y_strong, y):
        # weak predictions serve as (fixed) targets
        target_probs = F.softmax(y.detach(), dim=1)
        confidence, _ = target_probs.max(dim=1)
        mask = (confidence > self.threshold).float()
        log_prob = F.log_softmax(y_strong / self.temperature, dim=1)
        per_sample = F.kl_div(log_prob, target_probs, reduction='none').sum(dim=1)
        # average over retained samples; max(..., 1) guards against an empty mask
        return (per_sample * mask).sum() / max(mask.sum(), 1)
================================================
FILE: tllib/translation/__init__.py
================================================
================================================
FILE: tllib/translation/cycada.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torch import Tensor
class SemanticConsistency(nn.Module):
    """
    Semantic consistency loss introduced by
    `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018)`.
    This helps to prevent label flipping during image translation.

    Args:
        ignore_index (tuple, optional): Specifies target values that are ignored
            and do not contribute to the input gradient. Default: ().
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
            be applied, ``'mean'``: the weighted mean of the output is taken,
            ``'sum'``: the output will be summed. Default: ``'mean'``

    Shape:
        - Input: :math:`(N, C)` where `C = number of classes`, or
          :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1`
          in the case of `K`-dimensional loss.
        - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or
          :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of
          K-dimensional loss.
        - Output: scalar. If :attr:`reduction` is ``'none'``, then the same size
          as the target.

    Examples::
        >>> loss = SemanticConsistency()
        >>> input = torch.randn(3, 5, requires_grad=True)
        >>> target = torch.empty(3, dtype=torch.long).random_(5)
        >>> output = loss(input, target)
        >>> output.backward()
    """

    def __init__(self, ignore_index=(), reduction='mean'):
        super(SemanticConsistency, self).__init__()
        self.ignore_index = ignore_index
        # Classes listed in ``ignore_index`` are remapped to -1 in forward(),
        # which this criterion is configured to ignore.
        self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # Fix: operate on a copy. The original implementation overwrote ignored
        # classes in the caller's tensor in place, silently corrupting labels
        # that might be used elsewhere.
        target = target.clone()
        for class_idx in self.ignore_index:
            target[target == class_idx] = -1
        return self.loss(input, target)
================================================
FILE: tllib/translation/cyclegan/__init__.py
================================================
from . import discriminator
from . import generator
from . import loss
from . import transform
from .discriminator import *
from .generator import *
from .loss import *
from .transform import *
================================================
FILE: tllib/translation/cyclegan/discriminator.py
================================================
"""
Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torch.nn import init
import functools
from .util import get_norm_layer, init_weights
class NLayerDiscriminator(nn.Module):
    """Construct a PatchGAN discriminator.

    Args:
        input_nc (int): the number of channels in input images.
        ndf (int): the number of filters in the last conv layer. Default: 64
        n_layers (int): the number of conv layers in the discriminator. Default: 3
        norm_layer (torch.nn.Module): normalization layer. Default: :class:`nn.BatchNorm2d`
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        # BatchNorm2d carries affine parameters, so a preceding conv bias would be
        # redundant; InstanceNorm2d (non-affine here) needs the bias.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kernel_size, padding = 4, 1
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, True),
        ]
        # Channel multiplier doubles at each step, capped at 8x ndf.
        channel_mults = [min(2 ** n, 8) for n in range(n_layers + 1)]
        in_mult = 1
        # Strided convolutions progressively halve the spatial resolution.
        for out_mult in channel_mults[1:n_layers]:
            layers += [
                nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kernel_size,
                          stride=2, padding=padding, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ]
            in_mult = out_mult
        # One extra convolution at stride 1 before the prediction head.
        out_mult = channel_mults[n_layers]
        layers += [
            nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kernel_size,
                      stride=1, padding=padding, bias=use_bias),
            norm_layer(ndf * out_mult),
            nn.LeakyReLU(0.2, True),
        ]
        # Head: map features to a 1-channel real/fake prediction map.
        layers.append(nn.Conv2d(ndf * out_mult, 1, kernel_size=kernel_size, stride=1, padding=padding))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
class PixelDiscriminator(nn.Module):
    """1x1 PatchGAN discriminator (pixelGAN): classifies each pixel independently.

    Args:
        input_nc (int): the number of channels in input images.
        ndf (int): the number of filters in the last conv layer. Default: 64
        norm_layer (torch.nn.Module): normalization layer. Default: :class:`nn.BatchNorm2d`
    """

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        super(PixelDiscriminator, self).__init__()
        # Conv bias is only needed when the norm layer has no affine parameters.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        # All convolutions are 1x1, so the receptive field of each output is a single pixel.
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

    def forward(self, input):
        return self.net(input)
def patch(ndf, input_nc=3, norm='batch', n_layers=3, init_type='normal', init_gain=0.02):
    """
    PatchGAN classifier described in the original pix2pix paper.

    It can classify whether 70x70 overlapping patches are real or fake.
    Such a patch-level discriminator architecture has fewer parameters
    than a full-image discriminator and can work on arbitrarily-sized images
    in a fully convolutional fashion.

    Args:
        ndf (int): the number of filters in the first conv layer
        input_nc (int): the number of channels in input images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        n_layers (int): the number of conv layers in the discriminator. Default: 3
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    discriminator = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers,
                                        norm_layer=get_norm_layer(norm_type=norm))
    init_weights(discriminator, init_type, init_gain=init_gain)
    return discriminator
def pixel(ndf, input_nc=3, norm='batch', init_type='normal', init_gain=0.02):
    """
    1x1 PixelGAN discriminator that classifies whether each pixel is real or not.
    It encourages greater color diversity but has no effect on spatial statistics.

    Args:
        ndf (int): the number of filters in the first conv layer
        input_nc (int): the number of channels in input images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    discriminator = PixelDiscriminator(input_nc, ndf, norm_layer=get_norm_layer(norm_type=norm))
    init_weights(discriminator, init_type, init_gain=init_gain)
    return discriminator
================================================
FILE: tllib/translation/cyclegan/generator.py
================================================
"""
Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import torch.nn as nn
import functools
from .util import get_norm_layer, init_weights
class ResnetBlock(nn.Module):
    """A ResNet block: two 3x3 convolutions with a skip connection.

    Original ResNet paper: https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block.

        Args:
            dim (int): the number of channels in the conv layers.
            padding_type (str): the name of padding layer: reflect | replicate | zero
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers.
            use_bias (bool): if the conv layer uses bias or not
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _make_padding(padding_type):
        """Return ``(explicit padding layers, conv padding)`` for ``padding_type``.

        Raises:
            NotImplementedError: if ``padding_type`` is not one of
                reflect | replicate | zero.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        elif padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        elif padding_type == 'zero':
            # Zero padding is handled by the conv layer itself.
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.

        Args:
            dim (int): the number of channels in the conv layer.
            padding_type (str): the name of padding layer: reflect | replicate | zero
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers.
            use_bias (bool): if the conv layer uses bias or not

        Returns a conv block (with a conv layer, a normalization layer, and a
        non-linearity layer (ReLU)).
        """
        # The padding-selection logic was duplicated verbatim twice in the
        # original implementation; it is factored into _make_padding.
        pad_layers, p = self._make_padding(padding_type)
        conv_block = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                                   norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        pad_layers, p = self._make_padding(padding_type)
        # Second conv has no ReLU: the activation happens after the residual add.
        conv_block += pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                                    norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)."""
        return x + self.conv_block(x)  # add skip connection
class ResnetGenerator(nn.Module):
    """Resnet-based generator: downsampling, a stack of ResNet blocks, then upsampling.

    We adapt Torch code and idea from Justin Johnson's neural style transfer
    project (https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator.

        Args:
            input_nc (int): the number of channels in input images
            output_nc (int): the number of channels in output images
            ngf (int): the number of filters in the last conv layer
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers
            n_blocks (int): the number of ResNet blocks
            padding_type (str): the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        n_downsampling = 2
        # Stem: reflection-padded 7x7 convolution keeps the spatial size.
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]
        # Encoder: each step halves resolution and doubles channels.
        for step in range(n_downsampling):
            in_ch = ngf * 2 ** step
            layers += [nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(in_ch * 2),
                       nn.ReLU(True)]
        # Bottleneck: residual blocks at the lowest resolution.
        bottleneck_ch = ngf * 2 ** n_downsampling
        layers += [ResnetBlock(bottleneck_ch, padding_type=padding_type, norm_layer=norm_layer,
                               use_dropout=use_dropout, use_bias=use_bias)
                   for _ in range(n_blocks)]
        # Decoder: mirror of the encoder, transposed convolutions upsample.
        for step in range(n_downsampling):
            in_ch = ngf * 2 ** (n_downsampling - step)
            layers += [nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=3, stride=2,
                                          padding=1, output_padding=1, bias=use_bias),
                       norm_layer(in_ch // 2),
                       nn.ReLU(True)]
        # Head: 7x7 convolution to output_nc channels, squashed into [-1, 1] by tanh.
        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class UnetGenerator(nn.Module):
    """U-Net generator, built recursively from :class:`UnetSkipConnectionBlock`."""

    def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet generator.

        Args:
            input_nc (int): the number of channels in input images
            output_nc (int): the number of channels in output images
            num_downs (int): the number of downsamplings in UNet. For example,
                if num_downs == 7, an image of size 128x128 becomes 1x1 at the bottleneck.
            ngf (int): the number of filters in the last conv layer
            norm_layer (torch.nn.Module): normalization layer

        The U-Net is assembled from the innermost layer outward.
        """
        super(UnetGenerator, self).__init__()
        # Innermost (bottleneck) block.
        block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                        norm_layer=norm_layer, innermost=True)
        # Intermediate blocks keep ngf * 8 filters; dropout is applied here if requested.
        for _ in range(num_downs - 5):
            block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=block,
                                            norm_layer=norm_layer, use_dropout=use_dropout)
        # Gradually reduce the number of filters from ngf * 8 down to ngf.
        for outer_mult in (4, 2, 1):
            block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * outer_mult * 2, input_nc=None,
                                            submodule=block, norm_layer=norm_layer)
        # Outermost block maps the real input/output channel counts.
        self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block,
                                             outermost=True, norm_layer=norm_layer)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
    """U-Net submodule with a skip connection.

    X -------------------identity----------------------
    |-- downsampling -- |submodule| -- upsampling --|
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        """Construct a Unet submodule with skip connections.

        Args:
            outer_nc (int): the number of filters in the outer conv layer
            inner_nc (int): the number of filters in the inner conv layer
            input_nc (int): the number of channels in input images/features
            submodule (UnetSkipConnectionBlock): previously defined submodules
            outermost (bool): if this module is the outermost module
            innermost (bool): if this module is the innermost module
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers.
        """
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # Conv bias is only needed when the norm layer lacks affine parameters.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc

        down_conv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        down_act = nn.LeakyReLU(0.2, True)
        up_act = nn.ReLU(True)

        if outermost:
            # No normalization at either end; tanh produces the final image.
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            layers = [down_conv, submodule, up_act, up_conv, nn.Tanh()]
        elif innermost:
            # Bottleneck: no submodule, so the up-conv sees only inner_nc channels.
            up_conv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            layers = [down_act, down_conv, up_act, up_conv, norm_layer(outer_nc)]
        else:
            # The submodule concatenates its skip connection, doubling the
            # channel count fed to the up-conv.
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            layers = [down_act, down_conv, norm_layer(inner_nc),
                      submodule, up_act, up_conv, norm_layer(outer_nc)]
            if use_dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        # The U-Net skip connection: concatenate input and output along channels.
        return torch.cat([x, self.model(x)], 1)
def resnet_9(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    Resnet-based generator with 9 Resnet blocks.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    generator = ResnetGenerator(input_nc, output_nc, ngf,
                                norm_layer=get_norm_layer(norm_type=norm),
                                use_dropout=use_dropout, n_blocks=9)
    init_weights(generator, init_type, init_gain)
    return generator
def resnet_6(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    Resnet-based generator with 6 Resnet blocks.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    generator = ResnetGenerator(input_nc, output_nc, ngf,
                                norm_layer=get_norm_layer(norm_type=norm),
                                use_dropout=use_dropout, n_blocks=6)
    init_weights(generator, init_type, init_gain)
    return generator
def unet_256(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    U-Net generator for 256x256 input images (8 downsampling steps).
    The size of the input image should be a multiple of 256.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    generator = UnetGenerator(input_nc, output_nc, 8, ngf,
                              norm_layer=get_norm_layer(norm_type=norm), use_dropout=use_dropout)
    init_weights(generator, init_type, init_gain)
    return generator
def unet_128(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    U-Net generator for 128x128 input images (7 downsampling steps).
    The size of the input image should be a multiple of 128.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    generator = UnetGenerator(input_nc, output_nc, 7, ngf,
                              norm_layer=get_norm_layer(norm_type=norm), use_dropout=use_dropout)
    init_weights(generator, init_type, init_gain)
    return generator
def unet_32(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
            init_type='normal', init_gain=0.02):
    """
    U-Net generator for 32x32 input images (5 downsampling steps).

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method. Choices includes: ``normal`` |
            ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    generator = UnetGenerator(input_nc, output_nc, 5, ngf,
                              norm_layer=get_norm_layer(norm_type=norm), use_dropout=use_dropout)
    init_weights(generator, init_type, init_gain)
    return generator
================================================
FILE: tllib/translation/cyclegan/loss.py
================================================
"""
Modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
import torch
class LeastSquaresGenerativeAdversarialLoss(nn.Module):
    """
    Loss for `Least Squares Generative Adversarial Network (LSGAN)`.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super(LeastSquaresGenerativeAdversarialLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction=reduction)

    def forward(self, prediction, real=True):
        # Regress predictions toward 1 for real images and 0 for fake ones.
        target = torch.ones_like(prediction) if real else torch.zeros_like(prediction)
        return self.mse_loss(prediction, target)
class VanillaGenerativeAdversarialLoss(nn.Module):
    """
    Loss for the original (vanilla) `Generative Adversarial Network`.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator:
        :class:`nn.BCEWithLogitsLoss` applies the sigmoid internally.
    """

    def __init__(self, reduction='mean'):
        super(VanillaGenerativeAdversarialLoss, self).__init__()
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, prediction, real=True):
        # Binary cross-entropy against an all-ones (real) or all-zeros (fake) target.
        target = torch.ones_like(prediction) if real else torch.zeros_like(prediction)
        return self.bce_loss(prediction, target)
class WassersteinGenerativeAdversarialLoss(nn.Module):
    """
    Loss for `Wasserstein Generative Adversarial Network (WGAN)`.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator (critic) predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super(WassersteinGenerativeAdversarialLoss, self).__init__()
        # Fix: the original stored an unused nn.MSELoss here and silently
        # ignored ``reduction``, always taking the mean. Store and honor it
        # instead (default behavior is unchanged).
        self.reduction = reduction

    def forward(self, prediction, real=True):
        # The critic score is maximized for real samples, so the loss is its negation.
        signed = -prediction if real else prediction
        if self.reduction == 'mean':
            return signed.mean()
        elif self.reduction == 'sum':
            return signed.sum()
        return signed
================================================
FILE: tllib/translation/cyclegan/transform.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import torch.nn as nn
import torchvision.transforms as T
from tllib.vision.transforms import Denormalize
class Translation(nn.Module):
    """
    Image Translation Transform Module.

    Args:
        generator (torch.nn.Module): An image generator, e.g. a CycleGAN ResNet generator
        device (torch.device): device to put the generator. Default: 'cpu'
        mean (tuple): the normalized mean for image
        std (tuple): the normalized std for image

    Input:
        - image (PIL.Image): raw image in shape H x W x C

    Output:
        raw image in shape H x W x 3
    """

    def __init__(self, generator, device=torch.device("cpu"), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        super(Translation, self).__init__()
        self.device = device
        self.generator = generator.to(device)
        # PIL image -> normalized tensor expected by the generator.
        self.pre_process = T.Compose([
            T.ToTensor(),
            T.Normalize(mean, std),
        ])
        # Generator output -> de-normalized PIL image.
        self.post_process = T.Compose([
            Denormalize(mean, std),
            T.ToPILImage(),
        ])

    def forward(self, image):
        tensor = self.pre_process(image.copy()).to(self.device)  # C x H x W
        # Add a batch dimension for the generator, then strip it again.
        translated = self.generator(tensor.unsqueeze(dim=0)).squeeze(dim=0).cpu()
        return self.post_process(translated)
================================================
FILE: tllib/translation/cyclegan/util.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
import functools
import random
import torch
from torch.nn import init
class Identity(nn.Module):
    """No-op layer that returns its input unchanged (used when normalization is disabled)."""

    def forward(self, x):
        # Pass-through: no parameters, no transformation.
        return x
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested normalization layer.

    Parameters:
        norm_type (str) -- the name of the normalization layer: batch | instance | none

    For BatchNorm, we use learnable affine parameters and track running
    statistics (mean/stddev). For InstanceNorm, we use neither.

    Raises:
        NotImplementedError: for an unrecognized ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
    if norm_type == 'none':
        # Signature-compatible stand-in that ignores the channel count.
        return lambda num_features: Identity()
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Args:
        net (torch.nn.Module): network to be initialized
        init_type (str): the name of an initialization method. Choices includes:
            ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``
        init_gain (float): scaling factor for normal, xavier and orthogonal.

    'normal' is used in the original CycleGAN paper, but xavier and kaiming
    may work better for some applications.
    """
    def initialize(module):
        name = module.__class__.__name__
        if hasattr(module, 'weight') and ('Conv' in name or 'Linear' in name):
            if init_type == 'normal':
                init.normal_(module.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(module.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(module.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(module.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if getattr(module, 'bias', None) is not None:
                init.constant_(module.bias.data, 0.0)
        elif 'BatchNorm2d' in name:
            # BatchNorm weight is a vector, not a matrix: only a normal draw applies.
            init.normal_(module.weight.data, 1.0, init_gain)
            init.constant_(module.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(initialize)
class ImagePool:
    """An image buffer that stores previously generated images.

    This buffer enables us to update discriminators using a history of
    generated images rather than the ones produced by the latest generators.

    Args:
        pool_size (int): the size of image buffer; if pool_size=0, no buffer is created
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Lazily filled buffer of single-image tensors (each 1 x C x H x W).
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch built from the input and the history buffer.

        While the buffer is still filling, inputs are stored and returned
        as-is. Once full, each input image is (with probability 0.5) swapped
        for a randomly chosen buffered image — the input replaces it in the
        buffer — and otherwise returned unchanged.

        Args:
            images (torch.Tensor): the latest generated images from the generator
        """
        if self.pool_size == 0:
            # Buffering disabled: pass the batch straight through.
            return images
        selected = []
        for img in images:
            img = torch.unsqueeze(img.data, 0)
            if self.num_imgs < self.pool_size:
                # Buffer not yet full: store and also return the current image.
                self.num_imgs += 1
                self.images.append(img)
                selected.append(img)
            elif random.uniform(0, 1) > 0.5:
                # Swap: hand back a stored image, keep the current one instead.
                idx = random.randint(0, self.pool_size - 1)  # randint is inclusive
                stored = self.images[idx].clone()
                self.images[idx] = img
                selected.append(stored)
            else:
                selected.append(img)
        return torch.cat(selected, 0)
def set_requires_grad(net, requires_grad=False):
    """
    Enable or disable gradients for every parameter of ``net``.

    Freezing parameters (``requires_grad=False``) avoids unnecessary gradient
    computation, e.g. for a fixed discriminator.
    """
    for parameter in net.parameters():
        parameter.requires_grad = requires_grad
================================================
FILE: tllib/translation/fourier_transform.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import tqdm
import random
from PIL import Image
from typing import Optional, Sequence
import torch.nn as nn
def low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: Optional[int] = 1):
    """
    Args:
        amp_src (numpy.ndarray): amplitude component of the Fourier transform of source image
        amp_trg (numpy.ndarray): amplitude component of the Fourier transform of target image
        beta (int, optional): the size of the center region to be replaced. Default: 1

    Returns:
        The source amplitude whose low-frequency component is replaced by that
        of the target amplitude.
    """
    # Center the zero-frequency component so "low frequency" is a central square.
    shifted_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    shifted_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))
    # Low-frequency region: horizontal/vertical distance from the center <= beta.
    _, height, width = shifted_src.shape
    center_y = int(np.floor(height / 2.0))
    center_x = int(np.floor(width / 2.0))
    rows = slice(center_y - beta, center_y + beta + 1)
    cols = slice(center_x - beta, center_x + beta + 1)
    # Replace the source's low frequencies with the target's.
    shifted_src[:, rows, cols] = shifted_trg[:, rows, cols]
    # Undo the shift to restore the original frequency layout.
    return np.fft.ifftshift(shifted_src, axes=(-2, -1))
class FourierTransform(nn.Module):
    """
    Fourier Transform is introduced by `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) `_

    Fourier Transform replaces the low frequency component of the amplitude of the source image with that of the target image.
    Denote with :math:`M_{β}` a mask, whose value is zero except for the center region:

    .. math::
        M_{β}(h,w) = \mathbb{1}_{(h, w)\in [-β,β, -β, β]}

    Given images :math:`x^s` from source domain and :math:`x^t` from target domain, the source image in the target style is

    .. math::
        x^{s→t} = \mathcal{F}^{-1}([ M_{β}\circ\mathcal{F}^A(x^t) + (1-M_{β})\circ\mathcal{F}^A(x^s), \mathcal{F}^P(x^s) ])

    where :math:`\mathcal{F}^A`, :math:`\mathcal{F}^P` are the amplitude and phase component of the Fourier
    Transform :math:`\mathcal{F}` of an RGB image.

    Args:
        image_list (sequence[str]): A sequence of image list from the target domain.
        amplitude_dir (str): Specifies the directory to put the amplitude component of the target image.
        beta (int, optional): :math:`β`. Default: 1.
        rebuild (bool, optional): whether to rebuild the amplitude component of the target image in the given directory.

    Inputs:
        - image (PIL Image): image from the source domain, :math:`x^s`.

    Examples:

        >>> from tllib.translation.fourier_transform import FourierTransform
        >>> image_list = ["target_image_path1", "target_image_path2"]
        >>> amplitude_dir = "path/to/amplitude_dir"
        >>> fourier_transform = FourierTransform(image_list, amplitude_dir, beta=1, rebuild=False)
        >>> source_image = np.random.randn(256, 256, 3)  # image from source domain
        >>> source_image_in_target_style = fourier_transform(source_image)

    .. note::
        The meaning of :math:`β` is different from that of the origin paper. Experimentally, we found that the size of
        the center region in the frequency space should be constant when the image size increases. Thus we make the size
        of the center region independent of the image size. A recommended value for :math:`β` is 1.

    .. note::
        The image structure of the source domain and target domain should be as similar as possible,
        thus for segmentation tasks, FourierTransform should be used before RandomResizeCrop and other transformations.

    .. note::
        The image size of the source domain and the target domain need to be the same, thus before FourierTransform,
        you should use Resize to convert the source image to the target image size.

    Examples:

        >>> from tllib.translation.fourier_transform import FourierTransform
        >>> import tllib.vision.datasets.segmentation.transforms as T
        >>> from PIL import Image
        >>> target_image_list = ["target_image_path1", "target_image_path2"]
        >>> amplitude_dir = "path/to/amplitude_dir"
        >>> # build a fourier transform that translates source images to the target style
        >>> fourier_transform = T.wrapper(FourierTransform)(target_image_list, amplitude_dir)
        >>> transforms=T.Compose([
        ...     # convert source image to the size of the target image before fourier transform
        ...     T.Resize((2048, 1024)),
        ...     fourier_transform,
        ...     T.RandomResizedCrop((1024, 512)),
        ...     T.RandomHorizontalFlip(),
        ... ])
        >>> source_image = Image.open("path/to/source_image")  # image from the source domain
        >>> source_image_in_target_style = transforms(source_image)
    """
    # TODO add image examples when beta is different

    def __init__(self, image_list: Sequence[str], amplitude_dir: str,
                 beta: Optional[int] = 1, rebuild: Optional[bool] = False):
        super(FourierTransform, self).__init__()
        self.amplitude_dir = amplitude_dir
        # Target amplitudes are cached on disk; recompute only when the cache
        # directory is missing or an explicit rebuild is requested.
        if not os.path.exists(amplitude_dir) or rebuild:
            os.makedirs(amplitude_dir, exist_ok=True)
            self.build_amplitude(image_list, amplitude_dir)
        self.beta = beta
        self.length = len(image_list)

    @staticmethod
    def build_amplitude(image_list, amplitude_dir):
        """Precompute and save the Fourier amplitude of every target-domain image as ``{i}.npy``."""
        for i, image_name in enumerate(tqdm.tqdm(image_list)):
            image = Image.open(image_name).convert('RGB')
            # (H, W, C) -> (C, H, W) so the 2-d FFT runs over the spatial axes.
            image = np.asarray(image, np.float32)
            image = image.transpose((2, 0, 1))
            fft = np.fft.fft2(image, axes=(-2, -1))
            amp = np.abs(fft)
            np.save(os.path.join(amplitude_dir, "{}.npy".format(i)), amp)

    def forward(self, image):
        # randomly sample a target image and load its cached amplitude component
        amp_trg = np.load(os.path.join(self.amplitude_dir, "{}.npy".format(random.randint(0, self.length-1))))
        image = np.asarray(image, np.float32)
        image = image.transpose((2, 0, 1))
        # get fft, amplitude and phase of the source image
        fft_src = np.fft.fft2(image, axes=(-2, -1))
        amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)
        # replace the low-frequency amplitude of the source with that of the target
        amp_src_ = low_freq_mutate(amp_src, amp_trg, beta=self.beta)
        # recombine the mutated amplitude with the original source phase
        fft_src_ = amp_src_ * np.exp(1j * pha_src)
        # inverse FFT back to image space; any imaginary residue is numerical noise
        src_in_trg = np.fft.ifft2(fft_src_, axes=(-2, -1))
        src_in_trg = np.real(src_in_trg)
        src_in_trg = src_in_trg.transpose((1, 2, 0))
        src_in_trg = Image.fromarray(src_in_trg.clip(min=0, max=255).astype('uint8')).convert('RGB')
        return src_in_trg
================================================
FILE: tllib/translation/spgan/__init__.py
================================================
from . import siamese
from . import loss
from .siamese import *
from .loss import *
================================================
FILE: tllib/translation/spgan/loss.py
================================================
"""
Modified from https://github.com/Simon4Yan/eSPGAN
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn.functional as F
class ContrastiveLoss(torch.nn.Module):
    r"""Contrastive loss from `Dimensionality Reduction by Learning an Invariant Mapping (CVPR 2006)
    `_.

    Let :math:`D` be the pairwise euclidean distance between features :math:`f_1, f_2`,
    :math:`Y` the binary labels and :math:`m` the margin. The loss averaged over the batch is

    .. math::
        (1 - Y)D^2 + Y\,\text{max}(0, m-D)^2

    Args:
        margin (float, optional): margin for contrastive loss. Default: 2.0

    Inputs:
        - output1 (tensor): feature representations of the first set of samples (:math:`f_1` here).
        - output2 (tensor): feature representations of the second set of samples (:math:`f_2` here).
        - label (tensor): labels (:math:`Y` here).

    Shape:
        - output1, output2: :math:`(minibatch, F)` where F means the dimension of input features.
        - label: :math:`(minibatch, )`
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = F.pairwise_distance(output1, output2)
        # Similar pairs (label 0) are pulled together, dissimilar pairs (label 1)
        # are pushed apart until they are at least `margin` away.
        pull_term = (1 - label) * dist.pow(2)
        push_term = label * torch.clamp(self.margin - dist, min=0.0).pow(2)
        return (pull_term + push_term).mean()
================================================
FILE: tllib/translation/spgan/siamese.py
================================================
"""
Modified from https://github.com/Simon4Yan/eSPGAN
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    """Downsampling block: 4x4 strided conv, LeakyReLU, then 2x2 max-pool (spatial /4)."""

    def __init__(self, in_dim, out_dim):
        super(ConvBlock, self).__init__()
        layers = [
            nn.Conv2d(in_dim, out_dim, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv_block(x)
        return out
class SiameseNetwork(nn.Module):
    """Siamese network mapping an image of shape :math:`(3,H,W)` to an L2-normalized feature vector.

    Args:
        nsf (int): dimension of the output feature representation.
    """

    def __init__(self, nsf=64):
        super(SiameseNetwork, self).__init__()
        # Stem reduces resolution by 4; each ConvBlock by another 4 (64x overall).
        self.conv = nn.Sequential(
            nn.Conv2d(3, nsf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(nsf, nsf * 2),
            ConvBlock(nsf * 2, nsf * 4),
        )
        self.flatten = nn.Flatten()
        # NOTE(review): fc1 expects the flattened conv output to have 2048
        # elements, which ties the network to a fixed input size — verify
        # against the caller's image resolution.
        self.fc1 = nn.Linear(2048, nsf * 2, bias=False)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(nsf * 2, nsf, bias=False)

    def forward(self, x):
        features = self.flatten(self.conv(x))
        hidden = self.dropout(self.leaky_relu(self.fc1(features)))
        embedding = self.fc2(hidden)
        return F.normalize(embedding)
================================================
FILE: tllib/utils/__init__.py
================================================
from .logger import CompleteLogger
from .meter import *
from .data import ForeverDataIterator
__all__ = ['metric', 'analysis', 'meter', 'data', 'logger']
================================================
FILE: tllib/utils/analysis/__init__.py
================================================
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import tqdm
def collect_feature(data_loader: DataLoader, feature_extractor: nn.Module,
                    device: torch.device, max_num_features=None) -> torch.Tensor:
    """
    Run `feature_extractor` over `data_loader` and gather all outputs on the CPU.

    Args:
        data_loader (torch.utils.data.DataLoader): Data loader.
        feature_extractor (torch.nn.Module): A feature extractor.
        device (torch.device)
        max_num_features (int): If given, stop after this many mini-batches.

    Returns:
        Features in shape (min(len(data_loader), max_num_features) * mini-batch size, :math:`|\mathcal{F}|`).
    """
    feature_extractor.eval()
    collected = []
    with torch.no_grad():
        for batch_index, batch in enumerate(tqdm.tqdm(data_loader)):
            # The cap applies to the number of batches processed, not rows.
            if max_num_features is not None and batch_index >= max_num_features:
                break
            inputs = batch[0].to(device)
            collected.append(feature_extractor(inputs).cpu())
    return torch.cat(collected, dim=0)
================================================
FILE: tllib/utils/analysis/a_distance.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from torch.utils.data import TensorDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import SGD
from ..meter import AverageMeter
from ..metric import binary_accuracy
class ANet(nn.Module):
    """Single-layer binary domain discriminator: linear projection followed by a sigmoid."""

    def __init__(self, in_feature):
        super(ANet, self).__init__()
        self.layer = nn.Linear(in_feature, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        return self.sigmoid(self.layer(x))
def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor,
              device, progress=True, training_epochs=10):
    r"""
    Calculate the :math:`\mathcal{A}`-distance, a measure of distribution discrepancy.
    It is defined as :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon`
    is the test error of a classifier trained to discriminate source from target features.

    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        device (torch.device)
        progress (bool): if True, displays the progress of training the A-Net
        training_epochs (int): the number of epochs when training the classifier

    Returns:
        :math:`\mathcal{A}`-distance
    """
    # Domain labels: 1 for source features, 0 for target features.
    domain_labels = torch.cat([
        torch.ones((source_feature.shape[0], 1)),
        torch.zeros((target_feature.shape[0], 1)),
    ], dim=0)
    feature = torch.cat([source_feature, target_feature], dim=0)
    dataset = TensorDataset(feature, domain_labels)

    # Hold out 20% of the samples to estimate the discriminator's test error.
    num_total = len(dataset)
    num_train = int(0.8 * num_total)
    train_set, val_set = torch.utils.data.random_split(dataset, [num_train, num_total - num_train])
    train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

    anet = ANet(feature.shape[1]).to(device)
    optimizer = SGD(anet.parameters(), lr=0.01)
    a_distance = 2.0
    for epoch in range(training_epochs):
        # One binary-cross-entropy training pass over the held-in split.
        anet.train()
        for (x, y) in train_loader:
            x, y = x.to(device), y.to(device)
            anet.zero_grad()
            loss = F.binary_cross_entropy(anet(x), y)
            loss.backward()
            optimizer.step()

        # Evaluate on the held-out split and convert the error into an A-distance.
        anet.eval()
        meter = AverageMeter("accuracy", ":4.2f")
        with torch.no_grad():
            for (x, y) in val_loader:
                x, y = x.to(device), y.to(device)
                acc = binary_accuracy(anet(x), y)
                meter.update(acc, x.shape[0])
        error = 1 - meter.avg / 100
        a_distance = 2 * (1 - 2 * error)
        if progress:
            print("epoch {} accuracy: {} A-dist: {}".format(epoch, meter.avg, a_distance))
    return a_distance
================================================
FILE: tllib/utils/analysis/tsne.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import matplotlib
matplotlib.use('Agg')
from sklearn.manifold import TSNE
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as col
def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,
              filename: str, source_color='r', target_color='b'):
    """
    Plot a 2-d t-SNE embedding of source and target features and save it to disk.

    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        filename (str): the file name to save t-SNE
        source_color (str): the color of the source features. Default: 'r'
        target_color (str): the color of the target features. Default: 'b'
    """
    src = source_feature.numpy()
    tgt = target_feature.numpy()
    combined = np.concatenate([src, tgt], axis=0)
    # Embed both domains jointly so they share one 2-d space.
    embedded = TSNE(n_components=2, random_state=33).fit_transform(combined)
    # Domain labels: 1 for source, 0 for target (drives the two-color map below).
    domains = np.concatenate((np.ones(len(src)), np.zeros(len(tgt))))

    fig, ax = plt.subplots(figsize=(10, 10))
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    plt.scatter(embedded[:, 0], embedded[:, 1], c=domains,
                cmap=col.ListedColormap([target_color, source_color]), s=20)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(filename)
================================================
FILE: tllib/utils/data.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import itertools
import random
import numpy as np
import torch
from torch.utils.data import Sampler
from torch.utils.data import DataLoader, Dataset
from typing import TypeVar, Iterable, Dict, List
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
def send_to_device(tensor, device):
    """
    Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device.

    Args:
        tensor (nested list/tuple/dictionary of :obj:`torch.Tensor`):
            The data to send to a given device.
        device (:obj:`torch.device`):
            The device to send the data to

    Returns:
        The same data structure as :obj:`tensor` with all tensors sent to the proper device.
    """
    if isinstance(tensor, dict):
        return type(tensor)({key: send_to_device(value, device) for key, value in tensor.items()})
    if isinstance(tensor, (list, tuple)):
        return type(tensor)(send_to_device(item, device) for item in tensor)
    # Anything without a `.to` method (ints, strings, ...) is passed through unchanged.
    if hasattr(tensor, "to"):
        return tensor.to(device)
    return tensor
class ForeverDataIterator:
    r"""A data iterator that cycles through `data_loader` endlessly."""

    def __init__(self, data_loader: DataLoader, device=None):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)
        self.device = device

    def __next__(self):
        try:
            data = next(self.iter)
        except StopIteration:
            # Loader exhausted: restart from the beginning and fetch again.
            self.iter = iter(self.data_loader)
            data = next(self.iter)
        if self.device is not None:
            data = send_to_device(data, self.device)
        return data

    def __len__(self):
        return len(self.data_loader)
class RandomMultipleGallerySampler(Sampler):
    r"""Sampler from `In defense of the Triplet Loss for Person Re-Identification
    (ICCV 2017) `_. Assume there are :math:`N` identities in the dataset, this
    implementation simply samples :math:`K` images for every identity to form an iter of size :math:`N\times K`. During
    training, we will call ``__iter__`` method of pytorch dataloader once we reach a ``StopIteration``, this guarantees
    every image in the dataset will eventually be selected and we are not wasting any training data.

    Args:
        dataset(list): each element of this list is a tuple (image_path, person_id, camera_id)
        num_instances(int, optional): number of images to sample for every identity (:math:`K` here)
    """
    def __init__(self, dataset, num_instances=4):
        super(RandomMultipleGallerySampler, self).__init__(dataset)
        self.dataset = dataset
        self.num_instances = num_instances
        # dataset index -> person id
        self.idx_to_pid = {}
        # person id -> camera ids of all images of that person (parallel to idx_list_per_pid)
        self.cid_list_per_pid = {}
        # person id -> dataset indices of all images of that person
        self.idx_list_per_pid = {}
        for idx, (_, pid, cid) in enumerate(dataset):
            if pid not in self.cid_list_per_pid:
                self.cid_list_per_pid[pid] = []
                self.idx_list_per_pid[pid] = []
            self.idx_to_pid[idx] = pid
            self.cid_list_per_pid[pid].append(cid)
            self.idx_list_per_pid[pid].append(idx)
        self.pid_list = list(self.idx_list_per_pid.keys())
        self.num_samples = len(self.pid_list)

    def __len__(self):
        # One anchor plus (num_instances - 1) companions per identity.
        return self.num_samples * self.num_instances

    def __iter__(self):
        def select_idxes(element_list, target_element):
            """Positions in `element_list` whose value differs from `target_element`."""
            assert isinstance(element_list, list)
            return [i for i, element in enumerate(element_list) if element != target_element]

        # Visit every identity once, in random order.
        pid_idxes = torch.randperm(len(self.pid_list)).tolist()
        final_idxes = []
        for perm_id in pid_idxes:
            # Pick a random anchor image for this identity.
            i = random.choice(self.idx_list_per_pid[self.pid_list[perm_id]])
            _, _, cid = self.dataset[i]
            final_idxes.append(i)
            pid_i = self.idx_to_pid[i]
            cid_list = self.cid_list_per_pid[pid_i]
            idx_list = self.idx_list_per_pid[pid_i]
            # Prefer companion images taken by a different camera than the anchor.
            selected_cid_list = select_idxes(cid_list, cid)
            if selected_cid_list:
                # Sample num_instances - 1 companions; with replacement only
                # when there are not enough distinct candidates.
                if len(selected_cid_list) >= self.num_instances:
                    cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=False)
                else:
                    cid_idxes = np.random.choice(selected_cid_list, size=self.num_instances - 1, replace=True)
                for cid_idx in cid_idxes:
                    final_idxes.append(idx_list[cid_idx])
            else:
                # Every image of this identity shares the anchor's camera: fall
                # back to any other image of the same identity. Identities with
                # a single image contribute only their anchor.
                selected_idxes = select_idxes(idx_list, i)
                if not selected_idxes:
                    continue
                # NOTE(review): the assignments below rebind the name `pid_idxes`
                # used by the outer `for` loop. This is safe at runtime (the loop
                # iterator was bound when the loop started) but confusing —
                # consider renaming this local.
                if len(selected_idxes) >= self.num_instances:
                    pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=False)
                else:
                    pid_idxes = np.random.choice(selected_idxes, size=self.num_instances - 1, replace=True)
                for pid_idx in pid_idxes:
                    final_idxes.append(idx_list[pid_idx])
        return iter(final_idxes)
class CombineDataset(Dataset[T_co]):
    r"""Dataset as a combination of multiple datasets.

    The element of each dataset must be a list, and the i-th element of the combined dataset
    is a list splicing of the i-th element of each sub dataset.
    The length of the combined dataset is the minimum of the lengths of all sub datasets.

    Arguments:
        datasets (sequence): List of datasets to be concatenated
    """

    def __init__(self, datasets: Iterable[Dataset]) -> None:
        super(CombineDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(datasets) > 0, 'datasets should not be an empty iterable'  # type: ignore
        self.datasets = list(datasets)

    def __len__(self):
        # Truncate to the shortest sub-dataset so every index is valid everywhere.
        return min(len(dataset) for dataset in self.datasets)

    def __getitem__(self, idx):
        combined = []
        for dataset in self.datasets:
            combined.extend(dataset[idx])
        return combined
def concatenate(tensors):
    """Concatenate multiple batches into one batch.

    ``tensors`` can be :class:`torch.Tensor`, List or Dict, but all batches must
    share the same data format.

    Args:
        tensors (sequence): batches to concatenate along dimension 0.

    Returns:
        One batch with the same structure as each input batch.

    Raises:
        TypeError: if the batch elements are not Tensor, List or Dict.
    """
    head = tensors[0]
    if isinstance(head, torch.Tensor):
        return torch.cat(tensors, dim=0)
    elif isinstance(head, List):
        # Concatenate position-wise: the i-th elements of every batch are merged.
        return [concatenate([t[i] for t in tensors]) for i in range(len(head))]
    elif isinstance(head, Dict):
        # Concatenate key-wise: every batch is expected to share the same keys.
        return {k: concatenate([t[k] for t in tensors]) for k in head.keys()}
    # Previously, unsupported types fell through and returned None silently.
    raise TypeError("concatenate expects Tensor, List or Dict, got {}".format(type(head)))
================================================
FILE: tllib/utils/logger.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import sys
import time
class TextLogger(object):
    """Mirrors everything written to a stream into an external text file.

    Args:
        filename (str): the file that receives a copy of the stream output
        stream: the stream to mirror. Default: sys.stdout
    """

    def __init__(self, filename, stream=sys.stdout):
        self.terminal = stream
        self.log = open(filename, 'a')

    def write(self, message):
        # Write to the console and the file, then flush so nothing is lost on crash.
        for sink in (self.terminal, self.log):
            sink.write(message)
        self.flush()

    def flush(self):
        for sink in (self.terminal, self.log):
            sink.flush()

    def close(self):
        for sink in (self.terminal, self.log):
            sink.close()
class CompleteLogger:
    """
    A useful logger that
    - mirrors console output to a timestamped text file, and
    - manages the directories for checkpoints and debugging images.

    Args:
        root (str): the root directory of logger
        phase (str): the phase of training.
    """

    def __init__(self, root, phase='train'):
        self.root = root
        self.phase = phase
        self.visualize_directory = os.path.join(self.root, "visualize")
        self.checkpoint_directory = os.path.join(self.root, "checkpoints")
        self.epoch = 0
        for directory in (self.root, self.visualize_directory, self.checkpoint_directory):
            os.makedirs(directory, exist_ok=True)

        # Redirect stdout/stderr so everything printed is also written to disk.
        now = time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime(time.time()))
        log_filename = os.path.join(self.root, "{}-{}.txt".format(phase, now))
        if os.path.exists(log_filename):
            os.remove(log_filename)
        self.logger = TextLogger(log_filename)
        sys.stdout = self.logger
        sys.stderr = self.logger

        # Outside of training, artifacts are grouped by phase instead of epoch.
        if phase != 'train':
            self.set_epoch(phase)

    def set_epoch(self, epoch):
        """Set the epoch number. Please use it during training."""
        os.makedirs(os.path.join(self.visualize_directory, str(epoch)), exist_ok=True)
        self.epoch = epoch

    def _get_phase_or_epoch(self):
        # During training, artifacts are indexed by epoch; otherwise by phase.
        return str(self.epoch) if self.phase == 'train' else self.phase

    def get_image_path(self, filename: str):
        """
        Get the full image path for a specific filename
        """
        return os.path.join(self.visualize_directory, self._get_phase_or_epoch(), filename)

    def get_checkpoint_path(self, name=None):
        """
        Get the full checkpoint path.

        Args:
            name (optional): the filename (without file extension) to save checkpoint.
                If None, when the phase is ``train``, checkpoint will be saved to ``{epoch}.pth``.
                Otherwise, will be saved to ``{phase}.pth``.
        """
        if name is None:
            name = self._get_phase_or_epoch()
        return os.path.join(self.checkpoint_directory, str(name) + ".pth")

    def close(self):
        self.logger.close()
================================================
FILE: tllib/utils/meter.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List
class AverageMeter(object):
    r"""Tracks the most recent value and the running average of a series.

    Examples::

        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, name: str, fmt: Optional[str] = ':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `n` is the number of samples the value was averaged over.
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
class AverageMeterDict(object):
    """A collection of :class:`AverageMeter` objects addressed by name."""

    def __init__(self, names: List, fmt: Optional[str] = ':f'):
        self.dict = {name: AverageMeter(name, fmt) for name in names}

    def reset(self):
        for meter in self.dict.values():
            meter.reset()

    def update(self, accuracies, n=1):
        # `accuracies` maps meter names to their latest values.
        for name, value in accuracies.items():
            self.dict[name].update(value, n)

    def average(self):
        return {name: meter.avg for name, meter in self.dict.items()}

    def __getitem__(self, item):
        return self.dict[item]
class Meter(object):
    """Stores only the most recent value (no averaging)."""

    def __init__(self, name: str, fmt: Optional[str] = ':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0

    def update(self, val):
        self.val = val

    def __str__(self):
        return ('{name} {val' + self.fmt + '}').format(name=self.name, val=self.val)
class ProgressMeter(object):
    """Pretty-prints a batch counter followed by a list of meters.

    Args:
        num_batches (int): total number of batches (used to pad the counter).
        meters (sequence): meters whose ``__str__`` is shown on each display.
        prefix (str, optional): string printed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the current batch index and all meters on one tab-separated line."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total count,
        # e.g. num_batches=100 -> '[{:3d}/100]'.
        num_digits = len(str(num_batches))  # was `len(str(num_batches // 1))`; `// 1` was a no-op
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
================================================
FILE: tllib/utils/metric/__init__.py
================================================
import torch
import prettytable
__all__ = ['keypoint_detection']
def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> float:
    """Computes the accuracy (in percent) for binary classification, thresholding `output` at 0.5."""
    with torch.no_grad():
        n = target.size(0)
        predictions = (output >= 0.5).float().t().view(-1)
        num_correct = predictions.eq(target.view(-1)).float().sum()
        num_correct.mul_(100. / n)
        return num_correct
def accuracy(output, target, topk=(1,)):
    r"""
    Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`
        target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`
        topk (sequence[int]): A list of top-N number.

    Returns:
        Top-N accuracies (N :math:`\in` topK).
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # One top-maxk ranking serves every requested k.
        _, top_pred = output.topk(maxk, 1, True, True)
        hits = top_pred.t().eq(target.unsqueeze(0))

        results = []
        for k in topk:
            num_correct_k = hits[:k].flatten().sum(dtype=torch.float32)
            results.append(num_correct_k * (100.0 / batch_size))
        return results
class ConfusionMatrix(object):
    """Accumulates a confusion matrix (rows = targets, columns = predictions) over mini-batches."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None  # lazily allocated on the first update, on the target's device

    def update(self, target, output):
        """
        Update confusion matrix.

        Args:
            target: ground truth
            output: predictions of models

        Shape:
            - target: :math:`(minibatch, C)` where C means the number of classes.
            - output: :math:`(minibatch, C)` where C means the number of classes.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=target.device)
        with torch.no_grad():
            # Ignore targets outside [0, n); encode (target, output) pairs as
            # flat indices and count them in one bincount pass.
            valid = (target >= 0) & (target < n)
            flat = n * target[valid].to(torch.int64) + output[valid]
            self.mat += torch.bincount(flat, minlength=n ** 2).reshape(n, n)

    def reset(self):
        self.mat.zero_()

    def compute(self):
        """compute global accuracy, per-class accuracy and per-class IoU"""
        h = self.mat.float()
        diag = torch.diag(h)
        acc_global = diag.sum() / h.sum()
        acc = diag / h.sum(1)
        iu = diag / (h.sum(1) + h.sum(0) - diag)
        return acc_global, acc, iu

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)

    def format(self, classes: list):
        """Get the accuracy and IoU for each class in the table format"""
        acc_global, acc, iu = self.compute()
        table = prettytable.PrettyTable(["class", "acc", "iou"])
        for class_name, per_acc, per_iu in zip(classes, (acc * 100).tolist(), (iu * 100).tolist()):
            table.add_row([class_name, per_acc, per_iu])
        return 'global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}'.format(
            acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string())
================================================
FILE: tllib/utils/metric/keypoint_detection.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
# TODO: add documentation
import numpy as np
def get_max_preds(batch_heatmaps):
    '''
    Locate the per-joint maxima of a batch of score maps.

    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    Returns (preds, maxvals): (x, y) coordinates and score of each maximum;
    coordinates are zeroed wherever the best score is not positive.
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    flat = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(flat, 2).reshape((batch_size, num_joints, 1))
    maxvals = np.amax(flat, 2).reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = preds[:, :, 0] % width            # x = column of the flat index
    preds[:, :, 1] = np.floor(preds[:, :, 1] / width)  # y = row of the flat index

    # Suppress joints whose best score is not positive.
    mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
    preds *= mask
    return preds, maxvals
def calc_dists(preds, target, normalize):
    '''
    Normalized distances between predicted and target keypoints.

    Returns an array of shape (num_joints, batch_size); entries are -1 where
    the target coordinates are not both greater than 1 (treated as invalid).
    '''
    preds = preds.astype(np.float32)
    target = target.astype(np.float32)
    dists = np.zeros((preds.shape[1], preds.shape[0]))
    for sample in range(preds.shape[0]):
        for joint in range(preds.shape[1]):
            if target[sample, joint, 0] > 1 and target[sample, joint, 1] > 1:
                scale = normalize[sample]
                diff = preds[sample, joint, :] / scale - target[sample, joint, :] / scale
                dists[joint, sample] = np.linalg.norm(diff)
            else:
                dists[joint, sample] = -1
    return dists
def dist_acc(dists, thr=0.5):
    '''Fraction of valid distances (entries != -1) below `thr`; -1 when nothing is valid.'''
    valid = np.not_equal(dists, -1)
    num_valid = valid.sum()
    if num_valid == 0:
        return -1
    return np.less(dists[valid], thr).sum() * 1.0 / num_valid
def accuracy(output, target, hm_type='gaussian', thr=0.5):
    '''
    Calculate accuracy according to PCK,
    but uses ground truth heatmap rather than x,y locations
    First value to be returned is average accuracy across 'idxs',
    followed by individual accuracies

    Args:
        output (numpy.ndarray): predicted heatmaps, shape (batch, joints, h, w)
        target (numpy.ndarray): ground-truth heatmaps, same shape as `output`
        hm_type (str): heatmap type; only 'gaussian' is supported
        thr (float): PCK threshold on the normalized distance

    Returns:
        (per-joint accuracies, average accuracy, number of valid joints, predicted coordinates)

    Raises:
        ValueError: if `hm_type` is not 'gaussian'
    '''
    if hm_type != 'gaussian':
        # Previously an unsupported type fell through to a NameError on `pred`.
        raise ValueError("unsupported heatmap type: {}".format(hm_type))

    joint_ids = list(range(output.shape[1]))
    pred, _ = get_max_preds(output)
    target_coords, _ = get_max_preds(target)
    h = output.shape[2]
    w = output.shape[3]
    # Distances are normalized by one tenth of the heatmap size.
    norm = np.ones((pred.shape[0], 2)) * np.array([h, w]) / 10
    dists = calc_dists(pred, target_coords, norm)

    acc = np.zeros(len(joint_ids))
    avg_acc = 0
    cnt = 0
    for i, joint_id in enumerate(joint_ids):
        acc[i] = dist_acc(dists[joint_id], thr)
        # dist_acc returns -1 for joints with no valid samples; skip those.
        if acc[i] >= 0:
            avg_acc = avg_acc + acc[i]
            cnt += 1
    avg_acc = avg_acc / cnt if cnt != 0 else 0
    return acc, avg_acc, cnt, pred
================================================
FILE: tllib/utils/metric/reid.py
================================================
# TODO: add documentation
"""
Modified from https://github.com/yxgeee/MMT
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
import os.path as osp
from collections import defaultdict
import time
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import average_precision_score
from tllib.utils.meter import AverageMeter, ProgressMeter
def unique_sample(ids_dict, num):
    """Randomly choose one instance for each person id; chosen instances will not be selected again.

    Args:
        ids_dict (dict): maps each person id to the list of gallery indices for that id.
        num (int): total number of gallery samples (length of the returned mask).

    Returns:
        numpy.ndarray: boolean mask with exactly one True entry per person id.
    """
    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is the correct dtype.
    mask = np.zeros(num, dtype=bool)
    for _, indices in ids_dict.items():
        i = np.random.choice(indices)
        mask[i] = True
    return mask
def cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams, topk=100, separate_camera_set=False,
        single_gallery_shot=False, first_match_break=False):
    """Compute Cumulative Matching Characteristics (CMC).

    Args:
        dist_mat (torch.Tensor): (m, n) distance matrix between m query and n gallery samples.
        query_ids / gallery_ids: person identity labels for queries / gallery entries.
        query_cams / gallery_cams: camera ids; gallery entries sharing both id and camera
            with the query are excluded as trivial matches.
        topk (int): length of the returned CMC curve.
        separate_camera_set (bool): if True, additionally drop every gallery sample
            taken by the query's camera.
        single_gallery_shot (bool): if True, repeatedly (10x) sample one gallery
            instance per id before counting matches (CUHK03-style protocol).
        first_match_break (bool): if True, only the first correct match contributes
            (Market-1501-style protocol).

    Returns:
        np.ndarray of shape (topk,): CMC accuracy at ranks 1..topk.

    Raises:
        RuntimeError: if no query has any valid gallery match.
    """
    dist_mat = dist_mat.cpu().numpy()
    m, n = dist_mat.shape
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)
    # Sort gallery by ascending distance; `matches` marks correct ids in ranked order.
    indices = np.argsort(dist_mat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    # Compute CMC for each query
    ret = np.zeros(topk)
    num_valid_queries = 0
    for i in range(m):
        # Filter out gallery entries with both the same id and the same camera
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        if separate_camera_set:
            # Filter out samples from same camera
            valid &= (gallery_cams[indices[i]] != query_cams[i])
        if not np.any(matches[i, valid]): continue
        if single_gallery_shot:
            # Average over 10 random single-shot samplings; bucket valid ranked
            # positions by person id for the per-id sampling below.
            repeat = 10
            gids = gallery_ids[indices[i][valid]]
            inds = np.where(valid)[0]
            ids_dict = defaultdict(list)
            for j, x in zip(inds, gids):
                ids_dict[x].append(j)
        else:
            repeat = 1
        for _ in range(repeat):
            if single_gallery_shot:
                # Randomly choose one instance for each id
                sampled = (valid & unique_sample(ids_dict, len(valid)))
                index = np.nonzero(matches[i, sampled])[0]
            else:
                index = np.nonzero(matches[i, valid])[0]
            delta = 1. / (len(index) * repeat)
            for j, k in enumerate(index):
                # `k - j` is the rank of this match once earlier correct matches
                # are removed from the ranked list.
                if k - j >= topk: break
                if first_match_break:
                    ret[k - j] += 1
                    break
                ret[k - j] += delta
        num_valid_queries += 1
    if num_valid_queries == 0:
        raise RuntimeError("No valid query")
    # Cumulative sum turns per-rank hit counts into the CMC curve.
    return ret.cumsum() / num_valid_queries
def mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams):
    """Compute mean average precision (mAP) over all valid queries.

    Gallery entries that share both identity and camera with the query are
    excluded; queries without any remaining correct match are skipped.
    """
    dist_mat = dist_mat.cpu().numpy()
    num_queries, _ = dist_mat.shape
    query_ids = np.asarray(query_ids)
    gallery_ids = np.asarray(gallery_ids)
    query_cams = np.asarray(query_cams)
    gallery_cams = np.asarray(gallery_cams)
    # Rank gallery entries by ascending distance per query.
    indices = np.argsort(dist_mat, axis=1)
    matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
    aps = []
    for i in range(num_queries):
        # Drop trivial matches: same person seen by the same camera.
        valid = ((gallery_ids[indices[i]] != query_ids[i]) |
                 (gallery_cams[indices[i]] != query_cams[i]))
        y_true = matches[i, valid]
        if not np.any(y_true):
            continue
        # Negate distances so that a smaller distance means a larger score.
        y_score = -dist_mat[i][indices[i]][valid]
        aps.append(average_precision_score(y_true, y_score))
    if len(aps) == 0:
        raise RuntimeError("No valid query")
    return np.mean(aps)
def re_ranking(q_g_dist, q_q_dist, g_g_dist, k1=20, k2=6, lambda_value=0.3):
    """Perform re-ranking with distance matrix between query and gallery images `q_g_dist`, distance matrix between
    query and query images `q_q_dist` and distance matrix between gallery and gallery images `g_g_dist`.

    This is the k-reciprocal-encoding re-ranking algorithm: each sample is
    encoded by its (expanded) k-reciprocal neighbor set, a Jaccard distance is
    computed between these encodings, and the final distance is a weighted
    blend of Jaccard and original distances controlled by ``lambda_value``.

    Args:
        q_g_dist / q_q_dist / g_g_dist (torch.Tensor): pairwise distance matrices.
        k1 (int): neighborhood size used for the k-reciprocal sets.
        k2 (int): neighborhood size used for local query expansion (skipped if 1).
        lambda_value (float): weight of the original distance in the final blend.

    Returns:
        np.ndarray of shape (num_query, num_gallery): re-ranked distances.
    """
    q_g_dist = q_g_dist.cpu().numpy()
    q_q_dist = q_q_dist.cpu().numpy()
    g_g_dist = g_g_dist.cpu().numpy()
    # Assemble the full (query+gallery) x (query+gallery) distance matrix.
    original_dist = np.concatenate(
        [np.concatenate([q_q_dist, q_g_dist], axis=1),
         np.concatenate([q_g_dist.T, g_g_dist], axis=1)],
        axis=0)
    original_dist = np.power(original_dist, 2).astype(np.float32)
    # Column-normalize by each sample's maximum distance, then transpose.
    original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0))
    V = np.zeros_like(original_dist).astype(np.float32)
    initial_rank = np.argsort(original_dist).astype(np.int32)
    query_num = q_g_dist.shape[0]
    gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1]
    all_num = gallery_num
    for i in range(all_num):
        # k-reciprocal neighbors: j is kept only if i is also among j's top-k1.
        forward_k_neigh_index = initial_rank[i, :k1 + 1]
        backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
        fi = np.where(backward_k_neigh_index == i)[0]
        k_reciprocal_index = forward_k_neigh_index[fi]
        k_reciprocal_expansion_index = k_reciprocal_index
        for j in range(len(k_reciprocal_index)):
            # Expand the set with each candidate's own (half-sized) reciprocal
            # set when it overlaps the current set by more than 2/3.
            candidate = k_reciprocal_index[j]
            candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2.)) + 1]
            candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index,
                                               :int(np.around(k1 / 2.)) + 1]
            fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0]
            candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate]
            if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len(
                    candidate_k_reciprocal_index):
                k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index)
        k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index)
        # Encode sample i by Gaussian-weighted membership of its expanded set.
        weight = np.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = 1. * weight / np.sum(weight)
    original_dist = original_dist[:query_num, ]
    if k2 != 1:
        # Local query expansion: average each encoding over its k2 nearest neighbors.
        V_qe = np.zeros_like(V, dtype=np.float32)
        for i in range(all_num):
            V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
        V = V_qe
        del V_qe
    del initial_rank
    # Inverted index: for each gallery position, the samples whose encoding uses it.
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(np.where(V[:, i] != 0)[0])
    jaccard_dist = np.zeros_like(original_dist, dtype=np.float32)
    for i in range(query_num):
        # Accumulate elementwise minimum of encodings (the Jaccard intersection).
        temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32)
        indNonZero = np.where(V[i, :] != 0)[0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]],
                                                                              V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2. - temp_min)
    final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value
    del original_dist
    del V
    del jaccard_dist
    # Keep only the query-vs-gallery sub-block.
    final_dist = final_dist[:query_num, query_num:]
    return final_dist
def extract_reid_feature(data_loader, model, device, normalize, print_freq=200):
    """Extract per-image features for person ReID.

    If `normalize` is True the features are L2-normalized so that `cosine`
    distance can be used downstream; otherwise `euclidean` distance applies.
    Returns a dict mapping image filename to its feature tensor.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(data_loader),
        [batch_time],
        prefix='Collect feature: ')
    # Inference only: no dropout / batch-norm updates, no gradients.
    model.eval()
    feature_dict = dict()
    with torch.no_grad():
        tick = time.time()
        for batch_idx, (images, filenames, _, _) in enumerate(data_loader):
            features = model(images.to(device))
            if normalize:
                features = F.normalize(features)
            feature_dict.update(zip(filenames, features))
            # measure elapsed time
            batch_time.update(time.time() - tick)
            tick = time.time()
            if batch_idx % print_freq == 0:
                progress.display(batch_idx)
    return feature_dict
def pairwise_distance(feature_dict, query, gallery):
    """Compute the (m, n) matrix of squared Euclidean distances between the
    query and gallery feature sets.

    Features are looked up by filename in `feature_dict`; `query` and
    `gallery` are sequences of (filename, pid, cid) triples. Computation is
    done on CPU because the full matrix may exceed GPU memory.
    """
    query_feats = torch.stack([feature_dict[name] for name, _, _ in query]).cpu()
    gallery_feats = torch.stack([feature_dict[name] for name, _, _ in gallery]).cpu()
    m = query_feats.size(0)
    n = gallery_feats.size(0)
    # flatten each feature to a vector
    query_feats = query_feats.reshape(m, -1)
    gallery_feats = gallery_feats.reshape(n, -1)
    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y  (broadcast over the grid)
    sq_q = query_feats.pow(2).sum(dim=1, keepdim=True)        # (m, 1)
    sq_g = gallery_feats.pow(2).sum(dim=1, keepdim=True).t()  # (1, n)
    return sq_q + sq_g - 2 * torch.matmul(query_feats, gallery_feats.t())
def evaluate_all(dist_mat, query, gallery, cmc_topk=(1, 5, 10), cmc_flag=False):
    """Report mAP (and, when `cmc_flag` is set, CMC scores) for `dist_mat`.

    `query` and `gallery` are sequences of (filename, pid, cid) triples.
    Returns mAP alone, or (top-1 CMC, mAP) when `cmc_flag` is True.
    """
    query_ids = [pid for _, pid, _ in query]
    gallery_ids = [pid for _, pid, _ in gallery]
    query_cams = [cid for _, _, cid in query]
    gallery_cams = [cid for _, _, cid in gallery]
    # Mean average precision is always computed and printed.
    mAP = mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams)
    print('Mean AP: {:4.1%}'.format(mAP))
    if not cmc_flag:
        return mAP
    # Market-1501-style CMC protocol.
    cmc_configs = {
        'config': dict(separate_camera_set=False, single_gallery_shot=False, first_match_break=True)
    }
    cmc_scores = {}
    for name, params in cmc_configs.items():
        cmc_scores[name] = cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams, **params)
    print('CMC Scores:')
    for k in cmc_topk:
        print('  top-{:<4}{:12.1%}'.format(k, cmc_scores['config'][k - 1]))
    return cmc_scores['config'][0], mAP
def validate(val_loader, model, query, gallery, device, criterion='cosine', cmc_flag=False, rerank=False):
    """Extract features, evaluate, and optionally re-rank and re-evaluate."""
    assert criterion in ['cosine', 'euclidean']
    # when criterion == 'cosine', features are L2-normalized per image
    normalize = (criterion == 'cosine')
    feature_dict = extract_reid_feature(val_loader, model, device, normalize)
    dist_mat = pairwise_distance(feature_dict, query, gallery)
    results = evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag)
    if rerank:
        # apply person re-ranking and evaluate again on the re-ranked distances
        print('Applying person re-ranking')
        dist_mat_query = pairwise_distance(feature_dict, query, query)
        dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery)
        dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery)
        results = evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag)
    return results
# Layout parameters (in pixels) for the image grids built by `visualize_ranked_results`.
GRID_SPACING = 10  # horizontal gap between neighbouring gallery tiles
QUERY_EXTRA_SPACING = 90  # extra gap separating the query tile from the ranked gallery tiles
# border width drawn around each image tile
BW = 5
GREEN = (0, 255, 0)  # BGR border color marking a correct match
RED = (0, 0, 255)  # BGR border color marking a mis-match
def visualize_ranked_results(data_loader, model, query, gallery, device, visualize_dir, criterion='cosine',
                             rerank=False, width=128, height=256, topk=10):
    """Visualize ranker results. We first compute pair-wise distance between query images and gallery images. Then for
    every query image, `topk` gallery images with least distance between given query image are selected. We plot the
    query image and selected gallery images together. A green border denotes a match, and a red one denotes a mis-match.

    Args:
        data_loader: loader yielding (images, filenames, _, _) batches for feature extraction.
        model: feature extractor network.
        query / gallery: sequences of (image_path, pid, cid) triples.
        device: torch device used for feature extraction.
        visualize_dir (str): output directory for the rendered grids (created if missing).
        criterion (str): 'cosine' (features normalized) or 'euclidean'.
        rerank (bool): apply k-reciprocal re-ranking before selecting top-k.
        width / height (int): tile size each image is resized to.
        topk (int): number of gallery images shown per query.
    """
    assert criterion in ['cosine', 'euclidean']
    normalize = (criterion == 'cosine')
    # compute pairwise distance matrix
    feature_dict = extract_reid_feature(data_loader, model, device, normalize)
    dist_mat = pairwise_distance(feature_dict, query, gallery)
    if rerank:
        dist_mat_query = pairwise_distance(feature_dict, query, query)
        dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery)
        dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery)
    # make dir if not exists
    os.makedirs(visualize_dir, exist_ok=True)
    dist_mat = dist_mat.numpy()
    num_q, num_g = dist_mat.shape
    print('query images: {}'.format(num_q))
    print('gallery images: {}'.format(num_g))
    assert num_q == len(query)
    assert num_g == len(gallery)
    # start visualizing; cv2 imported lazily so the module works without OpenCV
    import cv2
    sorted_idxes = np.argsort(dist_mat, axis=1)
    for q_idx in range(num_q):
        q_img_path, q_pid, q_cid = query[q_idx]
        q_img = cv2.imread(q_img_path)
        q_img = cv2.resize(q_img, (width, height))
        # use black border to denote query image
        q_img = cv2.copyMakeBorder(
            q_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )
        # resize back so the bordered tile matches the grid cell size
        q_img = cv2.resize(q_img, (width, height))
        num_cols = topk + 1
        # white canvas: query tile + topk gallery tiles + spacing
        grid_img = 255 * np.ones(
            (height, num_cols * width + topk * GRID_SPACING + QUERY_EXTRA_SPACING, 3), dtype=np.uint8
        )
        grid_img[:, :width, :] = q_img
        # collect top-k gallery images with smallest distance
        rank_idx = 1
        for g_idx in sorted_idxes[q_idx, :]:
            g_img_path, g_pid, g_cid = gallery[g_idx]
            # skip trivial matches: same person captured by the same camera
            invalid = (q_pid == g_pid) & (q_cid == g_cid)
            if not invalid:
                matched = (g_pid == q_pid)
                border_color = GREEN if matched else RED
                g_img = cv2.imread(g_img_path)
                g_img = cv2.resize(g_img, (width, height))
                g_img = cv2.copyMakeBorder(
                    g_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=border_color
                )
                g_img = cv2.resize(g_img, (width, height))
                # paste this tile into its column of the grid
                start = rank_idx * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING
                end = (rank_idx + 1) * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING
                grid_img[:, start:end, :] = g_img
                rank_idx += 1
                if rank_idx > topk:
                    break
        # save the grid under the query image's basename
        save_path = osp.basename(osp.splitext(q_img_path)[0])
        cv2.imwrite(osp.join(visualize_dir, save_path + '.jpg'), grid_img)
        if (q_idx + 1) % 100 == 0:
            print('Visualize {}/{}'.format(q_idx + 1, num_q))
    print('Visualization process is done, ranked results are saved to {}'.format(visualize_dir))
================================================
FILE: tllib/utils/scheduler.py
================================================
"""
Modified from https://github.com/yxgeee/MMT
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
from bisect import bisect_right
class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the
    number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Default: 0.1.
        warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase
            will be set to :math:`k*initial\_lr`
        warmup_steps (int): number of warm-up steps.
        warmup_method (str): "constant" denotes a constant learning rate during warm-up phase and "linear" denotes a
            linear-increasing learning rate during warm-up phase.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: if ``milestones`` is not increasing or ``warmup_method`` is unknown.
    """

    def __init__(
            self,
            optimizer,
            milestones,
            gamma=0.1,
            warmup_factor=1.0 / 3,
            warmup_steps=500,
            warmup_method="linear",
            last_epoch=-1,
    ):
        if not list(milestones) == sorted(milestones):
            # Bug fix: previously the format string and `milestones` were passed as two
            # separate ValueError arguments, so the message was never formatted.
            raise ValueError(
                "Milestones should be a list of increasing integers. Got {}".format(milestones)
            )
        if warmup_method not in ("constant", "linear"):
            # Bug fix: the two concatenated literals lacked a separator ("acceptedgot ...").
            raise ValueError(
                "Only 'constant' or 'linear' warmup_method accepted, "
                "got {}".format(warmup_method)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.warmup_factor = warmup_factor
        self.warmup_steps = warmup_steps
        self.warmup_method = warmup_method
        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for every param group at ``self.last_epoch``."""
        warmup_factor = 1
        if self.last_epoch < self.warmup_steps:
            if self.warmup_method == "constant":
                warmup_factor = self.warmup_factor
            elif self.warmup_method == "linear":
                # Linearly interpolate the factor from `warmup_factor` up to 1.
                alpha = float(self.last_epoch) / float(self.warmup_steps)
                warmup_factor = self.warmup_factor * (1 - alpha) + alpha
        # After warm-up, apply multi-step decay: gamma^(number of passed milestones).
        return [
            base_lr
            * warmup_factor
            * self.gamma ** bisect_right(self.milestones, self.last_epoch)
            for base_lr in self.base_lrs
        ]
================================================
FILE: tllib/vision/__init__.py
================================================
# Public sub-packages re-exported by `tllib.vision`.
__all__ = ['datasets', 'models', 'transforms']
================================================
FILE: tllib/vision/datasets/__init__.py
================================================
from .imagelist import ImageList
from .office31 import Office31
from .officehome import OfficeHome
from .visda2017 import VisDA2017
from .officecaltech import OfficeCaltech
from .domainnet import DomainNet
from .imagenet_r import ImageNetR
from .imagenet_sketch import ImageNetSketch
from .pacs import PACS
from .digits import *
from .aircrafts import Aircraft
from .cub200 import CUB200
from .stanford_cars import StanfordCars
from .stanford_dogs import StanfordDogs
from .coco70 import COCO70
from .oxfordpets import OxfordIIITPets
from .dtd import DTD
from .oxfordflowers import OxfordFlowers102
from .patchcamelyon import PatchCamelyon
from .retinopathy import Retinopathy
from .eurosat import EuroSAT
from .resisc45 import Resisc45
from .food101 import Food101
from .sun397 import SUN397
from .caltech101 import Caltech101
from .cifar import CIFAR10, CIFAR100
# Names exported on `from tllib.vision.datasets import *`.
# Bug fix: the list previously contained the module name "cub200" instead of the
# imported class "CUB200", making the star-import fail for that entry.
__all__ = ['ImageList', 'Office31', 'OfficeHome', "VisDA2017", "OfficeCaltech", "DomainNet", "ImageNetR",
           "ImageNetSketch", "Aircraft", "CUB200", "StanfordCars", "StanfordDogs", "COCO70", "OxfordIIITPets", "PACS",
           "DTD", "OxfordFlowers102", "PatchCamelyon", "Retinopathy", "EuroSAT", "Resisc45", "Food101", "SUN397",
           "Caltech101", "CIFAR10", "CIFAR100"]
================================================
FILE: tllib/vision/datasets/_util.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import List
from torchvision.datasets.utils import download_and_extract_archive
def download(root: str, file_name: str, archive_name: str, url_link: str):
    """
    Download file from internet url link.

    Args:
        root (str): The directory to put downloaded files.
        file_name (str): The name of the unzipped file.
        archive_name (str): The name of archive (zipped file) downloaded.
        url_link (str): The url link to download data.

    .. note::
        If `file_name` already exists under path `root`, then it is not downloaded again.
        Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`.
    """
    if not os.path.exists(os.path.join(root, file_name)):
        print("Downloading {}".format(file_name))
        try:
            download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False)
        except Exception:
            print("Fail to download {} from url link {}".format(archive_name, url_link))
            # Bug fix: corrected "check you internet" and the missing space that
            # produced "connection.Simply trying again".
            print('Please check your internet connection. '
                  "Simply trying again may be fine.")
            exit(0)
def check_exits(root: str, file_name: str):
    """Exit the program if `file_name` does not exist under directory `root`."""
    path = os.path.join(root, file_name)
    if os.path.exists(path):
        return
    print("Dataset directory {} not found under {}".format(file_name, root))
    exit(-1)
def read_list_from_file(file_name: str) -> List[str]:
    """Read a text file and return its lines, whitespace-stripped, as a list."""
    with open(file_name, "r") as f:
        return [line.strip() for line in f]
================================================
FILE: tllib/vision/datasets/aircrafts.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class Aircraft(ImageList):
    """`FVGC-Aircraft `_ \
    is a benchmark for the fine-grained visual categorization of aircraft. \
    The dataset contains 10,200 images of aircraft, with 100 images for each \
    of the 102 different aircraft variants.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    # (directory, archive name, download url) triples
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/449157d27987463cbdb1/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/06804f17fdb947aa9401/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/164996d09cc749abbdeb/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    CLASSES = ['707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700', '737-800',
               '737-900', '747-100', '747-200', '747-300', '747-400', '757-200', '757-300', '767-200', '767-300',
               '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', 'A321', 'A330-200',
               'A330-300', 'A340-200', 'A340-300', 'A340-500', 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12',
               'BAE 146-200', 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47',
               'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560',
               'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100',
               'DHC-8-300', 'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145',
               'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A-B', 'F-A-18', 'Falcon 2000', 'Falcon 900',
               'Fokker 100', 'Fokker 50', 'Fokker 70', 'Global Express', 'Gulfstream IV', 'Gulfstream V',
               'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', 'PA-28',
               'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', 'Tu-154', 'Yak-42']

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
                 **kwargs):
        if split == 'train':
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Bug fix: `map` passes each (file_name, archive, url) tuple as ONE
            # argument, so the previous two-parameter lambda raised TypeError
            # whenever download=False. Unpack the tuple inside the lambda instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(Aircraft, self).__init__(root, Aircraft.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/caltech101.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class Caltech101(ImageList):
    """`The Caltech101 Dataset `_ contains objects
    belonging to 101 categories with about 40 to 800 images per category. Most categories have about 50 images.
    The size of each image is roughly 300 x 200 pixels.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # (directory, archive name, download url) triples
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/d6d4b813a800403f835e/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/ed4d0de80da246f98171/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/db1c444200a848799683/?dl=1")
    ]

    def __init__(self, root, split='train', download=True, **kwargs):
        classes = ['accordion', 'airplanes', 'anchor', 'ant', 'background_google', 'barrel', 'bass', 'beaver',
                   'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon',
                   'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face',
                   'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin',
                   'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'faces', 'faces_easy',
                   'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano',
                   'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree',
                   'kangaroo', 'ketch', 'lamp', 'laptop', 'leopards', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly',
                   'menorah', 'metronome', 'minaret', 'motorbikes', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda',
                   'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner',
                   'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus',
                   'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly',
                   'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang']
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Bug fix: `map` passes each download_list tuple as ONE argument, so the
            # previous two-parameter lambda raised TypeError when download=False.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(Caltech101, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)),
                                         **kwargs)
================================================
FILE: tllib/vision/datasets/cifar.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from torchvision.datasets.cifar import CIFAR10 as CIFAR10Base, CIFAR100 as CIFAR100Base
class CIFAR10(CIFAR10Base):
    """`CIFAR10 `_ Dataset.

    Thin wrapper around torchvision's CIFAR10 that selects the train/test
    split by name and records the number of classes.
    """

    def __init__(self, root, split='train', transform=None, download=True):
        is_train = (split == 'train')
        super(CIFAR10, self).__init__(root, train=is_train, transform=transform, download=download)
        # CIFAR10 has 10 object categories.
        self.num_classes = 10
class CIFAR100(CIFAR100Base):
    """`CIFAR100 `_ Dataset.

    Thin wrapper around torchvision's CIFAR100 that selects the train/test
    split by name and records the number of classes.
    """

    def __init__(self, root, split='train', transform=None, download=True):
        is_train = (split == 'train')
        super(CIFAR100, self).__init__(root, train=is_train, transform=transform, download=download)
        # CIFAR100 has 100 fine-grained categories.
        self.num_classes = 100
================================================
FILE: tllib/vision/datasets/coco70.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class COCO70(ImageList):
    """COCO-70 dataset is a large-scale classification dataset (1000 images per class) created from
    `COCO `_ Dataset.
    It is used to explore the effect of fine-tuning with a large amount of data.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    # (directory, archive name, download url) triples
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/b008c0d823ad488c8be1/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/75a895576d5e4e59a88d/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/ec6e45bc830d42f0924a/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
               'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', 'bench', 'bird', 'cat', 'dog',
               'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
               'handbag', 'tie', 'suitcase', 'skis', 'kite', 'baseball_bat', 'skateboard', 'surfboard',
               'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana',
               'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake',
               'chair', 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
               'remote', 'keyboard', 'cell_phone', 'microwave', 'oven', 'sink', 'refrigerator', 'book', 'clock',
               'vase', 'teddy_bear']

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
                 **kwargs):
        if split == 'train':
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Bug fix: `map` passes each download_list tuple as ONE argument, so the
            # previous two-parameter lambda raised TypeError when download=False.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(COCO70, self).__init__(root, COCO70.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/cub200.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class CUB200(ImageList):
"""`Caltech-UCSD Birds-200-2011 `_ \
is a dataset for fine-grained visual recognition with 11,788 images in 200 bird species. \
It is an extended version of the CUB-200 dataset, roughly doubling the number of images.
Args:
root (str): Root directory of dataset
split (str, optional): The dataset split, supports ``train``, or ``test``.
sample_rate (int): The sampling rates to sample random ``training`` images for each category.
Choices include 100, 50, 30, 15. Default: 100.
download (bool, optional): If true, downloads the dataset from the internet and puts it \
in root directory. If dataset is already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a \
transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
.. note:: In `root`, there will exist following files after downloading.
::
train/
test/
image_list/
train_100.txt
train_50.txt
train_30.txt
train_15.txt
test.txt
"""
download_list = [
("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c2a5952eb18b466b9fb0/?dl=1"),
("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/63db4c49b57b43198b95/?dl=1"),
("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/72e95cccdcaf4b42b4eb/?dl=1"),
]
image_list = {
"train": "image_list/train_100.txt",
"train100": "image_list/train_100.txt",
"train50": "image_list/train_50.txt",
"train30": "image_list/train_30.txt",
"train15": "image_list/train_15.txt",
"test": "image_list/test.txt",
"test100": "image_list/test.txt",
}
CLASSES = ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani',
'005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet',
'009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird',
'013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal',
'018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee',
'022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant',
'026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow',
'031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo',
'034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher',
'038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher',
'041.Scissor_tailed_Flycatcher', '042.Vermilion_Flycatcher', '043.Yellow_bellied_Flycatcher',
'044.Frigatebird', '045.Northern_Fulmar', '046.Gadwall', '047.American_Goldfinch',
'048.European_Goldfinch', '049.Boat_tailed_Grackle', '050.Eared_Grebe',
'051.Horned_Grebe', '052.Pied_billed_Grebe', '053.Western_Grebe', '054.Blue_Grosbeak',
'055.Evening_Grosbeak', '056.Pine_Grosbeak', '057.Rose_breasted_Grosbeak', '058.Pigeon_Guillemot',
'059.California_Gull', '060.Glaucous_winged_Gull', '061.Heermann_Gull', '062.Herring_Gull',
'063.Ivory_Gull', '064.Ring_billed_Gull', '065.Slaty_backed_Gull', '066.Western_Gull',
'067.Anna_Hummingbird', '068.Ruby_throated_Hummingbird', '069.Rufous_Hummingbird', '070.Green_Violetear',
'071.Long_tailed_Jaeger', '072.Pomarine_Jaeger', '073.Blue_Jay', '074.Florida_Jay', '075.Green_Jay',
'076.Dark_eyed_Junco', '077.Tropical_Kingbird', '078.Gray_Kingbird', '079.Belted_Kingfisher',
'080.Green_Kingfisher', '081.Pied_Kingfisher', '082.Ringed_Kingfisher', '083.White_breasted_Kingfisher',
'084.Red_legged_Kittiwake', '085.Horned_Lark', '086.Pacific_Loon', '087.Mallard',
'088.Western_Meadowlark', '089.Hooded_Merganser', '090.Red_breasted_Merganser', '091.Mockingbird',
'092.Nighthawk', '093.Clark_Nutcracker', '094.White_breasted_Nuthatch', '095.Baltimore_Oriole',
'096.Hooded_Oriole', '097.Orchard_Oriole', '098.Scott_Oriole', '099.Ovenbird', '100.Brown_Pelican',
'101.White_Pelican', '102.Western_Wood_Pewee', '103.Sayornis', '104.American_Pipit',
'105.Whip_poor_Will', '106.Horned_Puffin', '107.Common_Raven', '108.White_necked_Raven',
'109.American_Redstart', '110.Geococcyx', '111.Loggerhead_Shrike', '112.Great_Grey_Shrike',
'113.Baird_Sparrow', '114.Black_throated_Sparrow', '115.Brewer_Sparrow', '116.Chipping_Sparrow',
'117.Clay_colored_Sparrow', '118.House_Sparrow', '119.Field_Sparrow', '120.Fox_Sparrow',
'121.Grasshopper_Sparrow', '122.Harris_Sparrow', '123.Henslow_Sparrow', '124.Le_Conte_Sparrow',
'125.Lincoln_Sparrow', '126.Nelson_Sharp_tailed_Sparrow', '127.Savannah_Sparrow', '128.Seaside_Sparrow',
'129.Song_Sparrow', '130.Tree_Sparrow', '131.Vesper_Sparrow', '132.White_crowned_Sparrow',
'133.White_throated_Sparrow', '134.Cape_Glossy_Starling', '135.Bank_Swallow', '136.Barn_Swallow',
'137.Cliff_Swallow', '138.Tree_Swallow', '139.Scarlet_Tanager', '140.Summer_Tanager', '141.Artic_Tern',
'142.Black_Tern', '143.Caspian_Tern', '144.Common_Tern', '145.Elegant_Tern', '146.Forsters_Tern',
'147.Least_Tern', '148.Green_tailed_Towhee', '149.Brown_Thrasher', '150.Sage_Thrasher',
'151.Black_capped_Vireo', '152.Blue_headed_Vireo', '153.Philadelphia_Vireo', '154.Red_eyed_Vireo',
'155.Warbling_Vireo', '156.White_eyed_Vireo', '157.Yellow_throated_Vireo', '158.Bay_breasted_Warbler',
'159.Black_and_white_Warbler', '160.Black_throated_Blue_Warbler', '161.Blue_winged_Warbler',
'162.Canada_Warbler', '163.Cape_May_Warbler', '164.Cerulean_Warbler', '165.Chestnut_sided_Warbler',
'166.Golden_winged_Warbler', '167.Hooded_Warbler', '168.Kentucky_Warbler', '169.Magnolia_Warbler',
'170.Mourning_Warbler', '171.Myrtle_Warbler', '172.Nashville_Warbler', '173.Orange_crowned_Warbler',
'174.Palm_Warbler', '175.Pine_Warbler', '176.Prairie_Warbler', '177.Prothonotary_Warbler',
'178.Swainson_Warbler', '179.Tennessee_Warbler', '180.Wilson_Warbler', '181.Worm_eating_Warbler',
'182.Yellow_Warbler', '183.Northern_Waterthrush', '184.Louisiana_Waterthrush', '185.Bohemian_Waxwing',
'186.Cedar_Waxwing', '187.American_Three_toed_Woodpecker', '188.Pileated_Woodpecker',
'189.Red_bellied_Woodpecker', '190.Red_cockaded_Woodpecker', '191.Red_headed_Woodpecker',
'192.Downy_Woodpecker', '193.Bewick_Wren', '194.Cactus_Wren', '195.Carolina_Wren', '196.House_Wren',
'197.Marsh_Wren', '198.Rock_Wren', '199.Winter_Wren', '200.Common_Yellowthroat']
def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
             **kwargs):
    """Build the data list for the requested split, optionally downloading the data first.

    Args:
        root (str): Root directory of the dataset.
        split (str): ``'train'`` or anything else (treated as test).
        sample_rate (int, optional): Sampling rate of the training split; ``'train' + str(sample_rate)``
            must be a key of ``self.image_list``. Default: 100.
        download (bool, optional): If true, download the archives in ``self.download_list``;
            otherwise only check that they already exist under `root`.
    """
    if split == 'train':
        list_name = 'train' + str(sample_rate)
        assert list_name in self.image_list
        data_list_file = os.path.join(root, self.image_list[list_name])
    else:
        data_list_file = os.path.join(root, self.image_list['test'])
    if download:
        list(map(lambda args: download_data(root, *args), self.download_list))
    else:
        # Fix: ``map`` over a single iterable calls the function with ONE argument
        # (the whole (name, file, url) tuple), so the previous two-parameter
        # lambda raised TypeError whenever download=False. Unpack the tuple.
        list(map(lambda args: check_exits(root, args[0]), self.download_list))
    super(CUB200, self).__init__(root, CUB200.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/digits.py
================================================
"""
@author: Junguang Jiang, Baixu Chen
@contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com
"""
import os
from typing import Optional, Tuple, Any
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class MNIST(ImageList):
    """`MNIST `_ Dataset.

    Args:
        root (str): Root directory of dataset where ``MNIST/processed/training.pt``
            and ``MNIST/processed/test.pt`` exist.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``.
            Default: ``"L"``
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """
    # (directory name, archive file name, download url) triples consumed by download_data/check_exits
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/16feadf7fb3641c2be9a/?dl=1"),
        ("mnist_train_image", "mnist_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1"),
        # ("mnist_test_image", "mnist_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1")
    ]
    image_list = {
        "train": "image_list/mnist_train.txt",
        "test": "image_list/mnist_test.txt"
    }
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, mode="L", split='train', download: Optional[bool] = True, **kwargs):
        assert split in ['train', 'test']
        data_list_file = os.path.join(root, self.image_list[split])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        assert mode in ['L', 'RGB']
        self.mode = mode
        super(MNIST, self).__init__(root, MNIST.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index
        return (tuple): (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Convert to the configured channel mode ('L' or 'RGB') before transforms.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    @classmethod
    def get_classes(cls):
        """Return the list of all class names."""
        return MNIST.CLASSES
class USPS(ImageList):
    """`USPS `_ Dataset.

    The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``.
    The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]``
    and make pixel values in ``[0, 255]``.

    Args:
        root (str): Root directory of dataset to store``USPS`` data files.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``.
            Default: ``"L"``
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    # (directory name, archive file name, download url) triples consumed by download_data/check_exits
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/721ceaf3c031413cb62f/?dl=1"),
        ("usps_train_image", "usps_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1"),
        # ("usps_test_image", "usps_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1")
    ]
    image_list = {
        "train": "image_list/usps_train.txt",
        "test": "image_list/usps_test.txt"
    }
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, mode="L", split='train', download: Optional[bool] = True, **kwargs):
        assert split in ['train', 'test']
        data_list_file = os.path.join(root, self.image_list[split])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        assert mode in ['L', 'RGB']
        self.mode = mode
        super(USPS, self).__init__(root, USPS.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index
        return (tuple): (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Convert to the configured channel mode ('L' or 'RGB') before transforms.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target
class SVHN(ImageList):
    """`SVHN `_ Dataset.

    Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
    we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
    expect the class labels to be in the range `[0, C-1]`

    .. warning::
        This class needs `scipy `_ to load data from `.mat` format.

    Args:
        root (str): Root directory of dataset where directory
            ``SVHN`` exists.
        mode (str): The channel mode for image. Choices include ``"L"``, ``"RGB"``.
            Default: ``"L"``
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    # (directory name, archive file name, download url) triples consumed by download_data/check_exits
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/12b35fb08f8049f98362/?dl=1"),
        ("svhn_image", "svhn_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/cc02de6cf81543378cce/?dl=1")
    ]
    # Unlike MNIST/USPS, a single class-balanced list is used for both roles.
    image_list = "image_list/svhn_balanced.txt"
    # image_list = "image_list/svhn.txt"
    CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    def __init__(self, root, mode="L", download: Optional[bool] = True, **kwargs):
        data_list_file = os.path.join(root, self.image_list)
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        assert mode in ['L', 'RGB']
        self.mode = mode
        super(SVHN, self).__init__(root, SVHN.CLASSES, data_list_file=data_list_file, **kwargs)

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index
        return (tuple): (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        # Convert to the configured channel mode ('L' or 'RGB') before transforms.
        img = self.loader(path).convert(self.mode)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target
class MNISTRGB(MNIST):
    """MNIST with every image converted to 3-channel ``RGB`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)
class USPSRGB(USPS):
    """USPS with every image converted to 3-channel ``RGB`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)
class SVHNRGB(SVHN):
    """SVHN with every image converted to 3-channel ``RGB`` mode."""

    def __init__(self, root, **kwargs):
        super().__init__(root, mode='RGB', **kwargs)
================================================
FILE: tllib/vision/datasets/domainnet.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class DomainNet(ImageList):
"""`DomainNet `_ (cleaned version, recommended)
See `Moment Matching for Multi-Source Domain Adaptation `_ for details.
Args:
root (str): Root directory of dataset
task (str): The task (domain) to create dataset. Choices include ``'c'``:clipart, \
``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch
split (str, optional): The dataset split, supports ``train``, or ``test``.
download (bool, optional): If true, downloads the dataset from the internet and puts it \
in root directory. If dataset is already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a \
transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
.. note:: In `root`, there will exist following files after downloading.
::
clipart/
infograph/
painting/
quickdraw/
real/
sketch/
image_list/
clipart.txt
...
"""
download_list = [
("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/bf0fe327e4b046eb89ba/?dl=1"),
("clipart", "clipart.tgz", "https://cloud.tsinghua.edu.cn/f/f0515164a4864220b98b/?dl=1"),
("infograph", "infograph.tgz", "https://cloud.tsinghua.edu.cn/f/98b19d5fc9884109a9cb/?dl=1"),
("painting", "painting.tgz", "https://cloud.tsinghua.edu.cn/f/11285ce9fbd34bb7b28c/?dl=1"),
("quickdraw", "quickdraw.tgz", "https://cloud.tsinghua.edu.cn/f/6faa9efb498b494abf66/?dl=1"),
("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/17a101842c564959b525/?dl=1"),
("sketch", "sketch.tgz", "https://cloud.tsinghua.edu.cn/f/b305add26e9d47349495/?dl=1"),
]
image_list = {
"c": "clipart",
"i": "infograph",
"p": "painting",
"q": "quickdraw",
"r": "real",
"s": "sketch",
}
CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil',
'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat',
'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench',
'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang',
'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket',
'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera',
'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan',
'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud',
'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile',
'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut',
'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant',
'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant',
'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower',
'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee',
'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones',
'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital',
'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream',
'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf',
'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster',
'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave',
'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom',
'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush',
'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut',
'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow',
'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato',
'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control',
'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw',
'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark',
'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag',
'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat',
'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo',
'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 'string_bean', 'submarine',
'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear',
'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China',
'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado',
'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt',
'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide',
'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag']
def __init__(self, root: str, task: str, split: Optional[str] = 'train', download: Optional[bool] = False, **kwargs):
    """Build the data list for the chosen domain and split, optionally downloading data first.

    Args:
        root (str): Root directory of the dataset.
        task (str): Domain key; must be one of ``self.image_list`` (``'c'``, ``'i'``, ``'p'``, ``'q'``, ``'r'``, ``'s'``).
        split (str, optional): ``'train'`` or ``'test'``. Default: ``'train'``.
        download (bool, optional): If true, download all archives; otherwise check they exist.
    """
    # Fix: the ``download`` annotation was Optional[float]; the value is a bool flag.
    assert task in self.image_list
    assert split in ['train', 'test']
    data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))
    print("loading {}".format(data_list_file))
    if download:
        list(map(lambda args: download_data(root, *args), self.download_list))
    else:
        list(map(lambda args: check_exits(root, args[0]), self.download_list))
    super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs)
@classmethod
def domains(cls):
    """Return the single-letter keys of all domains in this dataset."""
    return [domain for domain in cls.image_list]
================================================
FILE: tllib/vision/datasets/dtd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class DTD(ImageList):
    """
    `The Describable Textures Dataset (DTD) `_ is an \
    evolving collection of textural images in the wild, annotated with a series of human-centric attributes, \
    inspired by the perceptual properties of textures. \
    The task consists in classifying images of textural patterns (47 classes, with 120 training images each). \
    Some of the textures are banded, bubbly, meshed, lined, or porous. \
    The image size ranges between 300x300 and 640x640 pixels.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # (directory name, archive file name, download url) triples consumed by download_data/check_exits
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/2218bfa61bac46539dd7/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/08fd47d35fc94f36a508/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/15873fe162c343cca8ed/?dl=1"),
        ("validation", "validation.tgz", "https://cloud.tsinghua.edu.cn/f/75c9ab22ebea4c3b87e7/?dl=1"),
    ]
    CLASSES = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked',
               'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy',
               'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled',
               'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous',
               'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped',
               'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged']

    def __init__(self, root, split, download=False, **kwargs):
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(DTD, self).__init__(root, DTD.CLASSES, os.path.join(root, "image_list", "{}.txt".format(split)), **kwargs)
================================================
FILE: tllib/vision/datasets/eurosat.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class EuroSAT(ImageList):
    """
    `EuroSAT `_ dataset consists in classifying \
    Sentinel-2 satellite images into 10 different types of land use (Residential, \
    Industrial, River, Highway, etc). \
    The spatial resolution corresponds to 10 meters per pixel, and the image size \
    is 64x64 pixels.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    CLASSES = ['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture',
               'PermanentCrop', 'Residential', 'River', 'SeaLake']

    def __init__(self, root, split='train', download=False, **kwargs):
        if download:
            download_data(root, "eurosat", "eurosat.tgz", "https://cloud.tsinghua.edu.cn/f/9983d7ab86184d74bb17/?dl=1")
        else:
            check_exits(root, "eurosat")
        # Map the public split names onto the slice-style list-file names:
        # the first 21600 entries serve as 'train', the remainder as 'test'.
        if split == 'train':
            split = 'train[:21600]'
        else:
            split = 'train[21600:]'
        root = os.path.join(root, "eurosat")
        super(EuroSAT, self).__init__(root, EuroSAT.CLASSES, os.path.join(root, "imagelist", "{}.txt".format(split)), **kwargs)
================================================
FILE: tllib/vision/datasets/food101.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from torchvision.datasets.folder import ImageFolder
import os.path as osp
from ._util import download as download_data, check_exits
class Food101(ImageFolder):
    """`Food-101 `_ is a dataset
    for fine-grained visual recognition with 101,000 images in 101 food categories.

    Args:
        root (str): Root directory of dataset.
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.

    .. note:: In `root`, there will exist following files after downloading.
        ::

            train/
            test/
    """
    # (directory name, archive file name, download url) triples consumed by download_data/check_exits
    download_list = [
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/1d7bd727cc1e4ce2bef5/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/7e11992d7495417db32b/?dl=1")
    ]

    def __init__(self, root, split='train', transform=None, download=True):
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        # Each split lives in its own ImageFolder-style directory.
        super(Food101, self).__init__(osp.join(root, split), transform=transform)
        self.num_classes = 101
================================================
FILE: tllib/vision/datasets/imagelist.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import warnings
from typing import Optional, Callable, Tuple, Any, List, Iterable
import bisect
from torch.utils.data.dataset import Dataset, T_co, IterableDataset
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
class ImageList(datasets.VisionDataset):
    """A generic Dataset class for image classification

    Args:
        root (str): Root directory of dataset
        classes (list[str]): The names of all the classes
        data_list_file (str): File to read the image list from.
        transform (callable, optional): A function/transform that takes in an PIL image \
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `data_list_file`, each line has 2 values in the following format.
        ::

            source_dir/dog_xxx.png 0
            source_dir/cat_123.png 1
            target_dir/dog_xxy.png 0
            target_dir/cat_nsdf3.png 1

        The first value is the relative path of an image, and the second value is the label of the corresponding image.
        If your data_list_file has different formats, please over-ride :meth:`~ImageList.parse_data_file`.
    """

    def __init__(self, root: str, classes: List[str], data_list_file: str,
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.samples = self.parse_data_file(data_list_file)
        self.targets = [s[1] for s in self.samples]
        self.classes = classes
        self.class_to_idx = {cls: idx
                             for idx, cls in enumerate(self.classes)}
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, int]:
        """
        Args:
            index (int): Index
        return (tuple): (image, target) where target is index of the target class.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]:
        """Parse file to data list

        Args:
            file_name (str): The path of data file
        return (list): List of (image path, class_index) tuples
        """
        with open(file_name, "r") as f:
            data_list = []
            for line in f.readlines():
                split_line = line.split()
                # Robustness fix: skip blank lines instead of raising IndexError.
                if not split_line:
                    continue
                # The label is the last token; the path may contain spaces.
                target = split_line[-1]
                path = ' '.join(split_line[:-1])
                if not os.path.isabs(path):
                    path = os.path.join(self.root, path)
                target = int(target)
                data_list.append((path, target))
        return data_list

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """All possible domain in this dataset"""
        # Fix: ``raise NotImplemented`` raises a TypeError (NotImplemented is a
        # sentinel value, not an exception); raise the proper exception class.
        raise NotImplementedError
class MultipleDomainsDataset(Dataset[T_co]):
    r"""Dataset as a concatenation of multiple datasets.

    This class is useful to assemble different existing datasets.
    Each item of a wrapped dataset is expected to be a tuple; the id of the
    domain it came from is appended as the last element of the returned tuple.

    Args:
        datasets (sequence): List of datasets to be concatenated
    """
    datasets: List[Dataset[T_co]]
    cumulative_sizes: List[int]

    @staticmethod
    def cumsum(sequence):
        # Running totals of the dataset lengths, e.g. lengths [3, 2] -> [3, 5].
        totals, running = [], 0
        for dataset in sequence:
            running += len(dataset)
            totals.append(running)
        return totals

    def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None:
        super(MultipleDomainsDataset, self).__init__()
        # Cannot verify that datasets is Sized
        assert len(domains) > 0, 'datasets should not be an empty iterable'  # type: ignore[arg-type]
        domain_list = list(domains)
        for domain in domain_list:
            assert not isinstance(domain, IterableDataset), \
                "MultipleDomainsDataset does not support IterableDataset"
        self.datasets = self.domains = domain_list
        self.cumulative_sizes = self.cumsum(domain_list)
        self.domain_names = domain_names
        self.domain_ids = domain_ids

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        # Support negative indices the same way a plain list does.
        if idx < 0:
            if -idx > len(self):
                raise ValueError("absolute value of index should not exceed dataset length")
            idx += len(self)
        # Locate the sub-dataset holding position ``idx`` via binary search.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = 0 if dataset_idx == 0 else self.cumulative_sizes[dataset_idx - 1]
        sample = self.domains[dataset_idx][idx - offset]
        # Append the id of the domain this sample belongs to.
        return sample + (self.domain_ids[dataset_idx],)

    @property
    def cummulative_sizes(self):
        warnings.warn("cummulative_sizes attribute is renamed to "
                      "cumulative_sizes", DeprecationWarning, stacklevel=2)
        return self.cumulative_sizes
================================================
FILE: tllib/vision/datasets/imagenet_r.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class ImageNetR(ImageList):
"""ImageNet-R Dataset.
Args:
root (str): Root directory of dataset
task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \
``'D'``: dslr and ``'W'``: webcam.
download (bool, optional): If true, downloads the dataset from the internet and puts it \
in root directory. If dataset is already downloaded, it is not downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image and returns a \
transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
.. note:: You need to put ``train`` directory of ImageNet-1K and ``imagenet_r`` directory of ImageNet-R
manually in `root` directory.
DALIB will only download ImageList automatically.
In `root`, there will exist following files after preparing.
::
train/
n02128385/
...
val/
imagenet-r/
n02128385/
image_list/
imagenet-train.txt
imagenet-r.txt
art.txt
...
"""
download_list = [
("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1"),
]
image_list = {
"IN": "image_list/imagenet-train.txt",
"IN-val": "image_list/imagenet-val.txt",
"INR": "image_list/imagenet-r.txt",
"art": "art.txt",
"embroidery": "embroidery.txt",
"misc": "misc.txt",
"sculpture": "sculpture.txt",
"tattoo": "tattoo.txt",
"cartoon": "cartoon.txt",
"graffiti": "graffiti.txt",
"origami": "origami.txt",
"sketch": "sketch.txt",
"toy": "toy.txt",
"deviantart": "deviantart.txt",
"graphic": "graphic.txt",
"painting": "painting.txt",
"sticker": "sticker.txt",
"videogame": "videogame.txt"
}
CLASSES = ['n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214'
, 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 
'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677']
def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs):
    """Build the data list for the chosen task, optionally downloading the image list first.

    Args:
        root (str): Root directory of the dataset.
        task (str): Key into ``self.image_list`` (e.g. ``'IN'``, ``'INR'``, or a renditions style).
        split (str, optional): ``'train'``, ``'val'`` or ``'all'``. Default: ``'all'``.
        download (bool, optional): If true, download the image lists; otherwise check they exist.
    """
    assert task in self.image_list
    assert split in ["train", "val", "all"]
    # ImageNet-1K keeps separate train/val lists; remap the task key for the val split.
    if task == "IN" and split == "val":
        task = "IN-val"
    data_list_file = os.path.join(root, self.image_list[task])
    if download:
        list(map(lambda args: download_data(root, *args), self.download_list))
    else:
        # Fix: ``map`` over a single iterable calls the function with ONE argument
        # (the whole tuple), so the previous two-parameter lambda raised TypeError.
        # Unpack the (name, file, url) triple instead.
        list(map(lambda args: check_exits(root, args[0]), self.download_list))
    super(ImageNetR, self).__init__(root, ImageNetR.CLASSES, data_list_file=data_list_file, **kwargs)
@classmethod
def domains(cls):
    """Return the keys of all domains in this dataset."""
    return list(cls.image_list)
================================================
FILE: tllib/vision/datasets/imagenet_sketch.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os
from torchvision.datasets.imagenet import ImageNet
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class ImageNetSketch(ImageList):
    """ImageNet-Sketch Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'IN'``: ImageNet-1K train, \
            ``'IN-val'``: ImageNet-1K val and ``'sketch'``: ImageNet-Sketch.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: You need to put ``train`` directory, ``metabin`` of ImageNet-1K and ``sketch`` directory of ImageNet-Sketch
        manually in `root` directory.
        DALIB will only download ImageList automatically.
        In `root`, there will exist following files after preparing.
        ::

            metabin (from ImageNet)
            train/
                n02128385/
                ...
            val/
            sketch/
                n02128385/
            image_list/
                imagenet-train.txt
                sketch.txt
                ...
    """
    # Only the image lists are downloadable; the images must be prepared manually.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1"),
    ]
    image_list = {
        "IN": "image_list/imagenet-train.txt",
        "IN-val": "image_list/imagenet-val.txt",
        "sketch": "image_list/sketch.txt",
    }

    def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        assert split in ["train", "val", "all"]
        # ImageNet-1K keeps separate train/val lists; remap the task key for the val split.
        if task == "IN" and split == "val":
            task = "IN-val"
        data_list_file = os.path.join(root, self.image_list[task])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # Fix: ``map`` over a single iterable calls the function with one
            # argument (the whole tuple), so a two-parameter lambda raised
            # TypeError here. Unpack the (name, file, url) triple instead.
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        # Class names come from the torchvision ImageNet metadata under `root`.
        super(ImageNetSketch, self).__init__(root, ImageNet(root).classes, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        return list(cls.image_list.keys())
================================================
FILE: tllib/vision/datasets/keypoint_detection/__init__.py
================================================
from .rendered_hand_pose import RenderedHandPose
from .hand_3d_studio import Hand3DStudio, Hand3DStudioAll
from .freihand import FreiHand
from .surreal import SURREAL
from .lsp import LSP
from .human36m import Human36M
================================================
FILE: tllib/vision/datasets/keypoint_detection/freihand.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import json
import time
import torch
import os
import os.path as osp
from torchvision.datasets.utils import download_and_extract_archive
from ...transforms.keypoint_detection import *
from .keypoint_dataset import Hand21KeypointDataset
from .util import *
""" General util functions. """
def _assert_exist(p):
    """Abort with an informative message when path *p* does not exist."""
    assert os.path.exists(p), 'File does not exists: %s' % p
def json_load(p):
    """Read the JSON file at *p* (which must exist) and return the parsed object."""
    _assert_exist(p)
    with open(p, 'r') as handle:
        return json.load(handle)
def load_db_annotation(base_path, set_name=None):
    """Load the FreiHAND K/mano/xyz annotation files of *set_name* and zip them per sample."""
    if set_name is None:
        # only training set annotations are released so this is a valid default choice
        set_name = 'training'
    print('Loading FreiHAND dataset index ...')
    start = time.time()
    # the three per-sample annotation containers share a common naming scheme
    k_list, mano_list, xyz_list = [
        json_load(os.path.join(base_path, '%s_%s.json' % (set_name, kind)))
        for kind in ('K', 'mano', 'xyz')
    ]
    # all three containers must describe the same number of samples
    assert len(k_list) == len(mano_list), 'Size mismatch.'
    assert len(k_list) == len(xyz_list), 'Size mismatch.'
    print('Loading of %d samples done in %.2f seconds' % (len(k_list), time.time() - start))
    return list(zip(k_list, mano_list, xyz_list))
def projectPoints(xyz, K):
    """Project 3D coordinates *xyz* into image space with intrinsic matrix *K*."""
    points = np.array(xyz)
    intrinsics = np.array(K)
    # homogeneous projection followed by perspective division
    projected = (intrinsics @ points.T).T
    return projected[:, :2] / projected[:, -1:]
""" Dataset related functions. """
def db_size(set_name):
    """ Hardcoded size of the datasets. """
    # number of unique samples per split (training images exist in multiple 'versions')
    sizes = {'training': 32560, 'evaluation': 3960}
    if set_name in sizes:
        return sizes[set_name]
    assert 0, 'Invalid choice.'
class sample_version:
    """Names of the four post-processing versions of each FreiHAND training sample.

    Every unique sample is stored once per version; the image file index of a
    version is the sample id offset by a multiple of the training-set size.
    """
    gs = 'gs'  # green screen
    hom = 'hom'  # homogenized
    sample = 'sample'  # auto colorization with sample points
    auto = 'auto'  # auto colorization without sample points: automatic color hallucination

    # number of unique training samples; also the index offset between versions
    db_size = db_size('training')

    @classmethod
    def valid_options(cls):
        """Return all version names in their canonical storage order."""
        return [cls.gs, cls.hom, cls.sample, cls.auto]

    @classmethod
    def check_valid(cls, version):
        """Assert that *version* is one of the four known names."""
        assert version in cls.valid_options(), \
            'Invalid choice: "%s" (must be in %s)' % (version, cls.valid_options())

    @classmethod
    def map_id(cls, id, version):
        """Map a unique sample *id* to the image file index of the given *version*."""
        cls.check_valid(version)
        return cls.valid_options().index(version) * cls.db_size + id
class FreiHand(Hand21KeypointDataset):
    """`FreiHand Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
        task (str, optional): The post-processing option to create dataset. Choices include ``'gs'``: green screen \
            recording, ``'auto'``: auto colorization without sample points: automatic color hallucination, \
            ``'sample'``: auto colorization with sample points, ``'hom'``: homogenized, \
            and ``'all'``: all hands. Default: 'all'.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: In `root`, there will exist following files after downloading.
        ::
            *.json
            training/
            evaluation/
    """

    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        if download:
            # fetch and unpack the archive only when the extracted folders are missing
            if not osp.exists(osp.join(root, "training")) or not osp.exists(osp.join(root, "evaluation")):
                download_and_extract_archive("https://lmb.informatik.uni-freiburg.de/data/freihand/FreiHAND_pub_v2.zip",
                                             download_root=root, filename="FreiHAND_pub_v2.zip", remove_finished=False,
                                             extract_root=root)
        assert split in ['train', 'test', 'all']
        self.split = split
        assert task in ['all', 'gs', 'auto', 'sample', 'hom']
        self.task = task
        # 'all' concatenates the four post-processed versions of every sample
        if task == 'all':
            samples = self.get_samples(root, 'gs') + self.get_samples(root, 'auto') + self.get_samples(root, 'sample') + self.get_samples(root, 'hom')
        else:
            samples = self.get_samples(root, task)
        # deterministic shuffle so the train/test partition is reproducible across runs
        random.seed(42)
        random.shuffle(samples)
        samples_len = len(samples)
        # hold out at most 20% of the data (capped at 3200 samples) for testing
        samples_split = min(int(samples_len * 0.2), 3200)
        if self.split == 'train':
            samples = samples[samples_split:]
        elif self.split == 'test':
            samples = samples[:samples_split]
        super(FreiHand, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for the sample at *index*,
        where ``target`` is the 2D keypoint heatmap and ``meta`` carries the poses.
        """
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]  # per-keypoint depth, preserved through the 2D crop
        # Crop the images such that the hand is at the center of the image
        # The images will be 1.5 times larger than the hand
        # The crop process will change Xc and Yc, leaving Zc with no changes
        bounding_box = get_bounding_box(keypoint2d)
        w, h = image.size
        left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
        image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)
        # Change all hands to right hands
        # NOTE(review): the flip fires when 'left' is False, and get_samples always sets
        # 'left': False, so every FreiHand image is mirrored -- confirm this is intended
        if sample['left'] is False:
            image, keypoint2d = hflip(image, keypoint2d)
        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # recover the 3D pose matching the transformed 2D pose (depth kept fixed)
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)
        # normalize 2D pose: all keypoints are treated as visible
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)
        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
        z = keypoint3d_n[:, 2]
        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
            'z': z,
        }
        return image, target, target_weight, meta

    def get_samples(self, root, version='gs'):
        """Build one sample dict (paths, 2D/3D keypoints, camera intrinsics) per
        training image of the requested post-processing *version*.
        """
        # only training-set annotations are released (see load_db_annotation)
        set = 'training'
        # load annotations of this set
        db_data_anno = load_db_annotation(root, set)
        version_map = {
            'gs': sample_version.gs,
            'hom': sample_version.hom,
            'sample': sample_version.sample,
            'auto': sample_version.auto
        }
        samples = []
        for idx in range(db_size(set)):
            # the rgb file index encodes both the unique sample id and the version
            image_name = os.path.join(set, 'rgb',
                                      '%08d.jpg' % sample_version.map_id(idx, version_map[version]))
            mask_name = os.path.join(set, 'mask', '%08d.jpg' % idx)
            intrinsic_matrix, mano, keypoint3d = db_data_anno[idx]
            # the 2D annotation is obtained by projecting the 3D joints with the intrinsics
            keypoint2d = projectPoints(keypoint3d, intrinsic_matrix)
            sample = {
                'name': image_name,
                'mask_name': mask_name,
                'keypoint2d': keypoint2d,
                'keypoint3d': keypoint3d,
                'intrinsic_matrix': intrinsic_matrix,
                'left': False  # handedness flag consumed by the flip in __getitem__
            }
            samples.append(sample)
        return samples
================================================
FILE: tllib/vision/datasets/keypoint_detection/hand_3d_studio.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import json
import random
from PIL import ImageFile, Image
import torch
import os.path as osp
from .._util import download as download_data, check_exits
from .keypoint_dataset import Hand21KeypointDataset
from .util import *
ImageFile.LOAD_TRUNCATED_IMAGES = True
class Hand3DStudio(Hand21KeypointDataset):
    """`Hand-3d-Studio Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
        task (str, optional): The task to create dataset. Choices include ``'noobject'``: only hands without objects, \
            ``'object'``: only hands interacting with hands, and ``'all'``: all hands. Default: 'noobject'.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note::
        We found that the original H3D image is in high resolution while most part in an image is background,
        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            H3D_crop/
                annotation.json
                part1/
                part2/
                part3/
                part4/
                part5/
    """

    def __init__(self, root, split='train', task='noobject', download=True, **kwargs):
        assert split in ['train', 'test', 'all']
        self.split = split
        assert task in ['noobject', 'object', 'all']
        self.task = task
        if download:
            download_data(root, "H3D_crop", "H3D_crop.tar", "https://cloud.tsinghua.edu.cn/f/d4e612e44dc04d8eb01f/?dl=1")
        else:
            check_exits(root, "H3D_crop")
        # all data lives inside the pre-cropped H3D_crop subdirectory
        root = osp.join(root, "H3D_crop")
        # load labels
        annotation_file = os.path.join(root, 'annotation.json')
        print("loading from {}".format(annotation_file))
        with open(annotation_file) as f:
            samples = list(json.load(f))
        # filter samples by whether the hand interacts with an object
        if task == 'noobject':
            samples = [sample for sample in samples if int(sample['without_object']) == 1]
        elif task == 'object':
            samples = [sample for sample in samples if int(sample['without_object']) == 0]
        # deterministic shuffle so the train/test partition is reproducible across runs
        random.seed(42)
        random.shuffle(samples)
        samples_len = len(samples)
        # hold out at most 20% of the data (capped at 3200 samples) for testing
        samples_split = min(int(samples_len * 0.2), 3200)
        if split == 'train':
            samples = samples[samples_split:]
        elif split == 'test':
            samples = samples[:samples_split]
        super(Hand3DStudio, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for the sample at *index*,
        where ``target`` is the 2D keypoint heatmap and ``meta`` carries the poses.
        """
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]  # per-keypoint depth, unchanged by the 2D transform
        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # recover the 3D pose matching the transformed 2D pose (depth kept fixed)
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)
        # normalize 2D pose: all keypoints are treated as visible
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)
        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta
class Hand3DStudioAll(Hand3DStudio):
    """`Hand-3d-Studio Dataset `_ with ``task`` preset to ``'all'``,
    i.e. no filtering on whether the hand interacts with an object.
    """

    def __init__(self, root, task='all', **kwargs):
        # identical to Hand3DStudio except for the default task
        super().__init__(root, task=task, **kwargs)
================================================
FILE: tllib/vision/datasets/keypoint_detection/human36m.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import json
import tqdm
from PIL import ImageFile
import torch
from .keypoint_dataset import Body16KeypointDataset
from ...transforms.keypoint_detection import *
from .util import *
ImageFile.LOAD_TRUNCATED_IMAGES = True
class Human36M(Body16KeypointDataset):
    """`Human3.6M Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
            Default: ``train``.
        task (str, optional): Placeholder.
        download (bool, optional): Placeholder.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: You need to download Human36M manually.
        Ensure that there exist following files in the `root` directory before you using this class.
        ::
            annotations/
                Human36M_subject11_joint_3d.json
                ...
            images/

    .. note::
        We found that the original Human3.6M image is in high resolution while most part in an image is background,
        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.
        In `root`, there will exist following files after crop.
        ::
            Human36M_crop/
                annotations/
                    keypoints2d_11.json
                ...
    """

    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        assert split in ['train', 'test', 'all']
        self.split = split
        samples = []
        # subject indices used per split (they name the Human36M_subject{}_* files)
        if self.split == 'train':
            parts = [1, 5, 6, 7, 8]
        elif self.split == 'test':
            parts = [9, 11]
        else:
            parts = [1, 5, 6, 7, 8, 9, 11]
        for part in parts:
            annotation_file = os.path.join(root, 'annotations/keypoints2d_{}.json'.format(part))
            # lazily generate cropped images + cached 2D annotations on first use
            if not os.path.exists(annotation_file):
                self.preprocess(part, root)
            print("loading", annotation_file)
            with open(annotation_file) as f:
                samples.extend(json.load(f))
        # decrease the number of test samples to decrease the time spent on test
        random.seed(42)
        if self.split == 'test':
            samples = random.choices(samples, k=3200)
        super(Human36M, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for the sample at *index*,
        where ``target`` is the 2D keypoint heatmap and ``meta`` carries the poses.
        """
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, "crop_images", image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]  # per-keypoint depth, unchanged by the 2D transform
        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # recover the 3D pose matching the transformed 2D pose (depth kept fixed)
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)
        # normalize 2D pose: all keypoints are treated as visible
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)
        # normalize 3D pose:
        # center the pose at keypoint 9 and rescale so the distance between keypoint 0
        # and keypoint 9 equals 1.
        # NOTE(review): this code is shared with the hand datasets, where index 9 is the
        # middle-finger MCP; in the Body16 layout 9 is the head and 0 an ankle -- confirm
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta

    def preprocess(self, part, root):
        """Crop the raw images of subject *part* around the person and cache the
        resulting 2D/3D keypoints and adjusted intrinsics to ``keypoints2d_{part}.json``.
        Cropped images are written to ``crop_images/`` mirroring the original layout.
        """
        # selects/reorders the 16 keypoints used here from the raw joint array;
        # slot 7 (duplicate of 11) is a placeholder overwritten below
        body_index = [3, 2, 1, 4, 5, 6, 0, 11, 8, 10, 16, 15, 14, 11, 12, 13]
        image_size = 512  # side length of the saved square crops
        print("preprocessing part", part)
        camera_json = os.path.join(root, "annotations", "Human36M_subject{}_camera.json".format(part))
        data_json = os.path.join(root, "annotations", "Human36M_subject{}_data.json".format(part))
        joint_3d_json = os.path.join(root, "annotations", "Human36M_subject{}_joint_3d.json".format(part))
        with open(camera_json, "r") as f:
            cameras = json.load(f)
        with open(data_json, "r") as f:
            data = json.load(f)
        images = data['images']
        with open(joint_3d_json, "r") as f:
            joints_3d = json.load(f)
        data = []
        for i, image_data in enumerate(tqdm.tqdm(images)):
            # downsample: keep every 5th frame
            if i % 5 == 0:
                # 3D joints are indexed by action / subaction / frame
                keypoint3d = np.array(joints_3d[str(image_data["action_idx"])][str(image_data["subaction_idx"])][
                    str(image_data["frame_idx"])])
                keypoint3d = keypoint3d[body_index, :]
                # replace the placeholder slot 7 with the midpoint of keypoints 12 and 13
                keypoint3d[7, :] = 0.5 * (keypoint3d[12, :] + keypoint3d[13, :])
                # world -> camera coordinates via the per-camera extrinsics [R|t]
                camera = cameras[str(image_data["cam_idx"])]
                R, T = np.array(camera["R"]), np.array(camera['t'])[:, np.newaxis]
                extrinsic_matrix = np.concatenate([R, T], axis=1)
                keypoint3d_camera = np.matmul(extrinsic_matrix, np.hstack(
                    (keypoint3d, np.ones((keypoint3d.shape[0], 1)))).T)  # (3 x NUM_KEYPOINTS)
                Z_c = keypoint3d_camera[2:3, :]  # 1 x NUM_KEYPOINTS
                # build the pinhole intrinsic matrix from focal lengths f and principal point c
                f, c = np.array(camera["f"]), np.array(camera['c'])
                intrinsic_matrix = np.zeros((3, 3))
                intrinsic_matrix[0, 0] = f[0]
                intrinsic_matrix[1, 1] = f[1]
                intrinsic_matrix[0, 2] = c[0]
                intrinsic_matrix[1, 2] = c[1]
                intrinsic_matrix[2, 2] = 1
                # project to the image plane and divide by depth
                keypoint2d = np.matmul(intrinsic_matrix, keypoint3d_camera)  # (3 x NUM_KEYPOINTS)
                keypoint2d = keypoint2d[0: 2, :] / Z_c
                keypoint2d = keypoint2d.T
                src_image_path = os.path.join(root, "images", image_data['file_name'])
                tgt_image_path = os.path.join(root, "crop_images", image_data['file_name'])
                os.makedirs(os.path.dirname(tgt_image_path), exist_ok=True)
                image = Image.open(src_image_path)
                # crop around the person, 1.5x larger than the keypoint bounding box
                bounding_box = get_bounding_box(keypoint2d)
                w, h = image.size
                left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
                image, keypoint2d = crop(image, upper, left, lower - upper + 1, right - left + 1, keypoint2d)
                Z_c = Z_c.T
                # Calculate XYZ from uvz: back-project the cropped 2D pose to camera space
                uv1 = np.concatenate([np.copy(keypoint2d), np.ones((16, 1))],
                                     axis=1)  # NUM_KEYPOINTS x 3
                uv1 = uv1 * Z_c  # NUM_KEYPOINTS x 3
                keypoint3d_camera = np.matmul(np.linalg.inv(intrinsic_matrix), uv1.T).T
                # resize image will change camera intrinsic matrix
                w, h = image.size
                image = image.resize((image_size, image_size))
                image.save(tgt_image_path)
                zoom_factor = float(w) / float(image_size)
                keypoint2d /= zoom_factor
                intrinsic_matrix[0, 0] /= zoom_factor
                intrinsic_matrix[1, 1] /= zoom_factor
                intrinsic_matrix[0, 2] /= zoom_factor
                intrinsic_matrix[1, 2] /= zoom_factor
                data.append({
                    "name": image_data['file_name'],
                    'keypoint2d': keypoint2d.tolist(),
                    'keypoint3d': keypoint3d_camera.tolist(),
                    'intrinsic_matrix': intrinsic_matrix.tolist(),
                })
        # cache the processed annotations so __init__ can load them directly next time
        with open(os.path.join(root, "annotations", "keypoints2d_{}.json".format(part)), "w") as f:
            json.dump(data, f)
================================================
FILE: tllib/vision/datasets/keypoint_detection/keypoint_dataset.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from abc import ABC
import numpy as np
from torch.utils.data.dataset import Dataset
from webcolors import name_to_rgb
import cv2
class KeypointDataset(Dataset, ABC):
    """A generic dataset class for image keypoint detection

    Args:
        root (str): Root directory of dataset
        num_keypoints (int): Number of keypoints
        samples (list): list of data
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2
        keypoints_group (dict): a dict that stores the index of different types of keypoints
        colored_skeleton (dict): a dict that stores the index and color of different skeleton
    """

    def __init__(self, root, num_keypoints, samples, transforms=None, image_size=(256, 256), heatmap_size=(64, 64),
                 sigma=2, keypoints_group=None, colored_skeleton=None):
        self.root = root
        self.samples = samples
        self.num_keypoints = num_keypoints
        self.transforms = transforms
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma
        self.colored_skeleton = colored_skeleton
        self.keypoints_group = keypoints_group

    def __len__(self):
        return len(self.samples)

    def visualize(self, image, keypoints, filename):
        """Visualize an image with its keypoints, and store the result into a file

        Args:
            image (PIL.Image):
            keypoints (torch.Tensor): keypoints in shape K x 2
            filename (str): the name of file to store
        """
        assert self.colored_skeleton is not None
        canvas = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()
        # draw every limb chain as a polyline in its configured color
        for line, color in self.colored_skeleton.values():
            for begin, end in zip(line[:-1], line[1:]):
                p0, p1 = keypoints[begin], keypoints[end]
                cv2.line(canvas, (int(p0[0]), int(p0[1])), (int(p1[0]), int(p1[1])), color=name_to_rgb(color),
                         thickness=3)
        # mark each keypoint with a small black circle
        for point in keypoints:
            cv2.circle(canvas, (int(point[0]), int(point[1])), 3, name_to_rgb('black'), 1)
        cv2.imwrite(filename, canvas)

    def group_accuracy(self, accuracies):
        """ Group the accuracy of K keypoints into different kinds.

        Args:
            accuracies (list): accuracy of the K keypoints

        Returns:
            accuracy of ``N=len(keypoints_group)`` kinds of keypoints
        """
        return {
            name: sum(accuracies[idx] for idx in indices) / len(indices)
            for name, indices in self.keypoints_group.items()
        }
class Body16KeypointDataset(KeypointDataset, ABC):
    """
    Dataset with 16 body keypoints.
    """
    # TODO: add image
    # keypoint-index groups used for per-part accuracy reporting
    head = (9,)
    shoulder = (12, 13)
    elbow = (11, 14)
    wrist = (10, 15)
    hip = (2, 3)
    knee = (1, 4)
    ankle = (0, 5)
    all = (12, 13, 11, 14, 10, 15, 2, 3, 1, 4, 0, 5)
    # limb chains used only for skeleton visualization
    right_leg = (0, 1, 2, 8)
    left_leg = (5, 4, 3, 8)
    backbone = (8, 9)
    right_arm = (10, 11, 12, 8)
    left_arm = (15, 14, 13, 8)

    def __init__(self, root, samples, **kwargs):
        # assemble the skeleton/group dictionaries from the class attributes above
        limb_colors = (("right_leg", 'yellow'), ("left_leg", 'green'), ("backbone", 'blue'),
                       ("right_arm", 'purple'), ("left_arm", 'red'))
        colored_skeleton = {name: (getattr(self, name), color) for name, color in limb_colors}
        group_names = ("head", "shoulder", "elbow", "wrist", "hip", "knee", "ankle", "all")
        keypoints_group = {name: getattr(self, name) for name in group_names}
        super(Body16KeypointDataset, self).__init__(root, 16, samples, keypoints_group=keypoints_group,
                                                    colored_skeleton=colored_skeleton, **kwargs)
class Hand21KeypointDataset(KeypointDataset, ABC):
    """
    Dataset with 21 hand keypoints.
    """
    # TODO: add image
    # keypoint-index groups used for per-joint-type accuracy reporting
    MCP = (1, 5, 9, 13, 17)
    PIP = (2, 6, 10, 14, 18)
    DIP = (3, 7, 11, 15, 19)
    fingertip = (4, 8, 12, 16, 20)
    all = tuple(range(21))
    # finger chains (each starting at the wrist, index 0) used for visualization
    thumb = (0, 1, 2, 3, 4)
    index_finger = (0, 5, 6, 7, 8)
    middle_finger = (0, 9, 10, 11, 12)
    ring_finger = (0, 13, 14, 15, 16)
    little_finger = (0, 17, 18, 19, 20)

    def __init__(self, root, samples, **kwargs):
        # assemble the skeleton/group dictionaries from the class attributes above
        finger_colors = (("thumb", 'yellow'), ("index_finger", 'green'), ("middle_finger", 'blue'),
                         ("ring_finger", 'purple'), ("little_finger", 'red'))
        colored_skeleton = {name: (getattr(self, name), color) for name, color in finger_colors}
        group_names = ("MCP", "PIP", "DIP", "fingertip", "all")
        keypoints_group = {name: getattr(self, name) for name in group_names}
        super(Hand21KeypointDataset, self).__init__(root, 21, samples, keypoints_group=keypoints_group,
                                                    colored_skeleton=colored_skeleton, **kwargs)
================================================
FILE: tllib/vision/datasets/keypoint_detection/lsp.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import scipy.io as scio
import os
from PIL import ImageFile
import torch
from .keypoint_dataset import Body16KeypointDataset
from ...transforms.keypoint_detection import *
from .util import *
from .._util import download as download_data, check_exits
ImageFile.LOAD_TRUNCATED_IMAGES = True
class LSP(Body16KeypointDataset):
    """`Leeds Sports Pose Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): PlaceHolder.
        task (str, optional): Placeholder.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): PlaceHolder.
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: In `root`, there will exist following files after downloading.
        ::
            lsp/
                images/
                joints.mat

    .. note::
        LSP is only used for target domain. Due to the small dataset size, the whole dataset is used
        no matter what ``split`` is. Also, the transform is fixed.
    """

    def __init__(self, root, split='train', task='all', download=True, image_size=(256, 256), transforms=None, **kwargs):
        if download:
            download_data(root, "images", "lsp_dataset.zip",
                          "https://cloud.tsinghua.edu.cn/f/46ea73c89abc46bfb125/?dl=1")
        else:
            # NOTE(review): checks for an "lsp" folder while the download path checks
            # "images" -- confirm which directory name is intended
            check_exits(root, "lsp")
        assert split in ['train', 'test', 'all']
        self.split = split
        samples = []
        # transpose the .mat annotations so the first axis indexes the 2000 images
        annotations = scio.loadmat(os.path.join(root, "joints.mat"))['joints'].transpose((2, 1, 0))
        for i in range(0, 2000):
            image = "im{0:04d}.jpg".format(i + 1)
            annotation = annotations[i]
            samples.append((image, annotation))
        # map the 14 LSP joints onto the 16-keypoint Body16 layout; slots 6 and 7 have
        # no LSP counterpart and reuse joint 13 as placeholders
        self.joints_index = (0, 1, 2, 3, 4, 5, 13, 13, 12, 13, 6, 7, 8, 9, 10, 11)
        # mask out the two placeholder slots so they never contribute to the target
        self.visible = np.array([1.] * 6 + [0, 0] + [1.] * 8, dtype=np.float32)
        normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        # the `transforms` argument is intentionally overridden: LSP always uses this
        # fixed pipeline (see the class docstring note)
        transforms = Compose([
            ResizePad(image_size[0]),
            ToTensor(),
            normalize
        ])
        super(LSP, self).__init__(root, samples, transforms=transforms, image_size=image_size, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for the sample at *index*;
        LSP has no 3D annotations, so ``meta['keypoint3d']`` is all zeros.
        """
        sample = self.samples[index]
        image_name = sample[0]
        image = Image.open(os.path.join(self.root, "images", image_name))
        keypoint2d = sample[1][self.joints_index, :2]
        image, data = self.transforms(image, keypoint2d=keypoint2d)
        keypoint2d = data['keypoint2d']
        # combine the static placeholder mask with the inverted third annotation column
        # NOTE(review): this treats column 2 of joints.mat as an invisibility flag --
        # confirm against the LSP annotation format
        visible = self.visible * (1 - sample[1][self.joints_index, 2])
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)
        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': np.zeros((self.num_keypoints, 3)).astype(keypoint2d.dtype),  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta
================================================
FILE: tllib/vision/datasets/keypoint_detection/rendered_hand_pose.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import os
import pickle
from .._util import download as download_data, check_exits
from ...transforms.keypoint_detection import *
from .keypoint_dataset import Hand21KeypointDataset
from .util import *
class RenderedHandPose(Hand21KeypointDataset):
"""`Rendered Handpose Dataset `_
Args:
root (str): Root directory of dataset
split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
task (str, optional): Placeholder.
download (bool, optional): If true, downloads the dataset from the internet and puts it \
in root directory. If dataset is already downloaded, it is not downloaded again.
transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
image_size (tuple): (width, height) of the image. Default: (256, 256)
heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
sigma (int): sigma parameter when generate the heatmap. Default: 2
.. note:: In `root`, there will exist following files after downloading.
::
RHD_published_v2/
training/
evaluation/
"""
def __init__(self, root, split='train', task='all', download=True, **kwargs):
if download:
download_data(root, "RHD_published_v2", "RHD_v1-1.zip", "https://lmb.informatik.uni-freiburg.de/data/RenderedHandpose/RHD_v1-1.zip")
else:
check_exits(root, "RHD_published_v2")
root = os.path.join(root, "RHD_published_v2")
assert split in ['train', 'test', 'all']
self.split = split
if split == 'all':
samples = self.get_samples(root, 'train') + self.get_samples(root, 'test')
else:
samples = self.get_samples(root, split)
super(RenderedHandPose, self).__init__(
root, samples, **kwargs)
def __getitem__(self, index):
sample = self.samples[index]
image_name = sample['name']
image_path = os.path.join(self.root, image_name)
image = Image.open(image_path)
keypoint3d_camera = np.array(sample['keypoint3d']) # NUM_KEYPOINTS x 3
keypoint2d = np.array(sample['keypoint2d']) # NUM_KEYPOINTS x 2
intrinsic_matrix = np.array(sample['intrinsic_matrix'])
Zc = keypoint3d_camera[:, 2]
# Crop the images such that the hand is at the center of the image
# The images will be 1.5 times larger than the hand
# The crop process will change Xc and Yc, leaving Zc with no changes
bounding_box = get_bounding_box(keypoint2d)
w, h = image.size
left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)
# Change all hands to right hands
if sample['left'] is False:
image, keypoint2d = hflip(image, keypoint2d)
image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
keypoint2d = data['keypoint2d']
intrinsic_matrix = data['intrinsic_matrix']
keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)
# noramlize 2D pose:
visible = np.array(sample['visible'], dtype=np.float32)
visible = visible[:, np.newaxis]
# 2D heatmap
target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
target = torch.from_numpy(target)
target_weight = torch.from_numpy(target_weight)
# normalize 3D pose:
# put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
# and make distance between wrist and middle finger MCP joint to be of length 1
keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
z = keypoint3d_n[:, 2]
meta = {
'image': image_name,
'keypoint2d': keypoint2d, # (NUM_KEYPOINTS x 2)
'keypoint3d': keypoint3d_n, # (NUM_KEYPOINTS x 3)
'z': z,
}
return image, target, target_weight, meta
def get_samples(self, root, task, min_size=64):
    """Build the list of per-hand samples for ``task``.

    ``task == 'train'`` reads the 'training' annotation set, anything else the
    'evaluation' set. A hand is kept when its 1.5x-scaled square crop is larger
    than ``min_size``, more than 16 of its 21 keypoints are flagged visible
    (presumably 0/1 flags — TODO confirm), and the other hand's bounding box
    overlaps less than 30% of the crop.
    """
    subset = 'training' if task == 'train' else 'evaluation'
    # load annotations of this set
    with open(os.path.join(root, subset, 'anno_%s.pickle' % subset), 'rb') as fi:
        anno_all = pickle.load(fi)

    # reorder raw joints into the dataset's canonical 21-keypoint layout
    left_hand_index = [0, 4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17]
    right_hand_index = [i + 21 for i in left_hand_index]
    w, h = 320, 320

    samples = []
    for sample_id, anno in anno_all.items():
        image_name = os.path.join(subset, 'color', '%.5d.png' % sample_id)
        mask_name = os.path.join(subset, 'mask', '%.5d.png' % sample_id)
        keypoint2d = anno['uv_vis'][:, :2]
        keypoint3d = anno['xyz']
        intrinsic_matrix = anno['K']
        visible = anno['uv_vis'][:, 2]

        # apply the same acceptance criteria to the left then the right hand
        for indices, other_indices, is_left in (
                (left_hand_index, right_hand_index, True),
                (right_hand_index, left_hand_index, False)):
            hand_keypoint2d = keypoint2d[indices]  # NUM_KEYPOINTS x 2
            scaled_box = scale_box(get_bounding_box(hand_keypoint2d), w, h, 1.5)
            other_box = get_bounding_box(keypoint2d[other_indices])
            left, upper, right, lower = scaled_box
            size = max(right - left, lower - upper)
            if size > min_size and np.sum(visible[indices]) > 16 \
                    and area(*intersection(scaled_box, other_box)) / area(*scaled_box) < 0.3:
                samples.append({
                    'name': image_name,
                    'mask_name': mask_name,
                    'keypoint2d': hand_keypoint2d,
                    'visible': visible[indices],
                    'keypoint3d': keypoint3d[indices],
                    'intrinsic_matrix': intrinsic_matrix,
                    'left': is_left
                })
    return samples
================================================
FILE: tllib/vision/datasets/keypoint_detection/surreal.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import json
from PIL import ImageFile
import torch
from ...transforms.keypoint_detection import *
from .util import *
from .._util import download as download_data, check_exits
from .keypoint_dataset import Body16KeypointDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SURREAL(Body16KeypointDataset):
    """`SURREAL Dataset <https://www.di.ens.fr/willow/research/surreal/data/>`_ for
    16-keypoint body pose estimation on synthetic humans.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``val``.
            Default: ``train``. The ``train``/``test`` splits are carved deterministically
            (seed 42) out of the shuffled annotation list of the corresponding directory.
        task (str, optional): Placeholder.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            val/
    """
    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        assert split in ['train', 'test', 'val']
        self.split = split

        if download:
            download_data(root, "train/run0", "train0.tgz", "https://cloud.tsinghua.edu.cn/f/b13604f06ff1445c830a/?dl=1")
            download_data(root, "train/run1", "train1.tgz", "https://cloud.tsinghua.edu.cn/f/919aefe2de3541c3b940/?dl=1")
            # fix: train2.tgz contains run2, so it must be checked/extracted against
            # "train/run2". The previous target "train/run1" meant run2 was never
            # downloaded once run1 existed, breaking the check_exits path below.
            download_data(root, "train/run2", "train2.tgz", "https://cloud.tsinghua.edu.cn/f/34864760ad4945b9bcd6/?dl=1")
            download_data(root, "val", "val.tgz", "https://cloud.tsinghua.edu.cn/f/16b20f2e76684f848dc1/?dl=1")
            download_data(root, "test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/36c72d86e43540e0a913/?dl=1")
        else:
            check_exits(root, "train/run0")
            check_exits(root, "train/run1")
            check_exits(root, "train/run2")
            check_exits(root, "val")
            check_exits(root, "test")

        # gather annotations from all three runs of the requested split directory
        all_samples = []
        for part in [0, 1, 2]:
            annotation_file = os.path.join(root, split, 'run{}.json'.format(part))
            print("loading", annotation_file)
            with open(annotation_file) as f:
                samples = json.load(f)
            for sample in samples:
                sample["image_path"] = os.path.join(root, self.split, 'run{}'.format(part), sample['name'])
            all_samples.extend(samples)

        # deterministic shuffle so the train/test carve-out is reproducible
        random.seed(42)
        random.shuffle(all_samples)
        samples_len = len(all_samples)
        samples_split = min(int(samples_len * 0.2), 3200)
        if self.split == 'train':
            all_samples = all_samples[samples_split:]
        elif self.split == 'test':
            all_samples = all_samples[:samples_split]

        # indices selecting the 16 body joints used from the raw annotated joints
        # (order presumably matches the Body16KeypointDataset convention — TODO confirm)
        self.joints_index = (7, 4, 1, 2, 5, 8, 0, 9, 12, 15, 20, 18, 13, 14, 19, 21)
        super(SURREAL, self).__init__(root, all_samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for the sample at ``index``."""
        sample = self.samples[index]
        image_name = sample['name']
        image_path = sample['image_path']
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])[self.joints_index, :]  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])[self.joints_index, :]  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]

        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # back-project the transformed 2D keypoints; the transform leaves depth Zc unchanged
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)

        # normalize 2D pose: all 16 joints are treated as visible
        visible = np.array([1.] * 16, dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        # normalize 3D pose:
        # center the coordinate system on joint 9 and scale so that the distance
        # between joint 0 and joint 9 has length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))

        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta

    def __len__(self):
        return len(self.samples)
================================================
FILE: tllib/vision/datasets/keypoint_detection/util.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import cv2
def generate_target(joints, joints_vis, heatmap_size, sigma, image_size):
    """Render one Gaussian heatmap per joint.

    Args:
        joints: (K, 2) joint positions in image coordinates
        joints_vis: (K, 1) visibility flags
        heatmap_size: (W, H) of the output heatmaps
        sigma: Gaussian standard deviation, in heatmap pixels
        image_size: (W, H) of the input image

    Returns:
        target: (K, H, W) float32 heatmaps with peak value 1 at each usable joint
        target_weight: (K, 1) float32 weights, zeroed for joints whose center
            falls outside the heatmap
    """
    num_joints = joints.shape[0]
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    target_weight[:, 0] = joints_vis[:, 0]

    image_size = np.array(image_size)
    heatmap_size = np.array(heatmap_size)
    target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]), dtype=np.float32)

    # The Gaussian patch is identical for every joint, so build it once.
    # It is deliberately NOT normalized: the center value equals 1.
    radius = sigma * 3
    side = 2 * radius + 1
    axis = np.arange(0, side, 1, np.float32)
    gx, gy = axis, axis[:, np.newaxis]
    center = side // 2
    patch = np.exp(- ((gx - center) ** 2 + (gy - center) ** 2) / (2 * sigma ** 2))

    feat_stride = image_size / heatmap_size
    for joint_id in range(num_joints):
        # joint position on the heatmap grid, rounded to the nearest cell
        mu_x = int(joints[joint_id][0] / feat_stride[0] + 0.5)
        mu_y = int(joints[joint_id][1] / feat_stride[1] + 0.5)
        if mu_x < 0 or mu_y < 0 or mu_x >= heatmap_size[0] or mu_y >= heatmap_size[1]:
            # center off the heatmap: mark the joint as unusable
            target_weight[joint_id] = 0
            continue

        ul = (int(mu_x - radius), int(mu_y - radius))
        br = (int(mu_x + radius + 1), int(mu_y + radius + 1))
        # clip both the source patch and the destination window to the heatmap bounds
        g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
        img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
        img_y = max(0, ul[1]), min(br[1], heatmap_size[1])

        if target_weight[joint_id] > 0.5:
            target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                patch[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return target, target_weight
def keypoint2d_to_3d(keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, Zc: np.ndarray):
    """Back-project 2D pixel keypoints to 3D camera coordinates given their depth ``Zc``."""
    ones = np.ones((keypoint2d.shape[0], 1))
    homogeneous = np.hstack([np.copy(keypoint2d), ones])  # NUM_KEYPOINTS x 3
    uv1 = homogeneous.T * Zc  # 3 x NUM_KEYPOINTS, scaled by depth
    xyz = np.matmul(np.linalg.inv(intrinsic_matrix), uv1)
    return xyz.T  # NUM_KEYPOINTS x 3
def keypoint3d_to_2d(keypoint3d: np.ndarray, intrinsic_matrix: np.ndarray):
    """Project 3D camera-space keypoints onto the image plane."""
    projected = np.matmul(intrinsic_matrix, keypoint3d.T).T  # NUM_KEYPOINTS x 3, homogeneous
    # divide by the depth component to get pixel coordinates
    keypoint2d = projected[:, :2] / projected[:, 2:3]  # NUM_KEYPOINTS x 2
    return keypoint2d
def scale_box(box, image_width, image_height, scale):
    """Turn ``box`` into a square box of side ``round(scale * max(w, h))``.

    The side is capped at ``min(image_width, image_height)`` and the square is
    shifted (not shrunk) back inside the image when it sticks out.
    Coordinates are inclusive: width == right - left + 1.
    """
    left, upper, right, lower = box
    cx = (left + right) / 2
    cy = (upper + lower) / 2
    side = min(round(scale * max(right - left, lower - upper)), min(image_width, image_height))

    left = round(cx - side / 2)
    right = left + side - 1
    upper = round(cy - side / 2)
    lower = upper + side - 1

    # clamp to the image, preserving the side length
    if left < 0:
        left, right = 0, side - 1
    if right >= image_width:
        left, right = image_width - side, image_width - 1
    if upper < 0:
        upper, lower = 0, side - 1
    if lower >= image_height:
        upper, lower = image_height - side, image_height - 1
    return left, upper, right, lower
def get_bounding_box(keypoint2d: np.array):
    """Return the axis-aligned bounding box (left, upper, right, lower) of the keypoints."""
    xs = keypoint2d[:, 0]
    ys = keypoint2d[:, 1]
    return xs.min(), ys.min(), xs.max(), ys.max()
def visualize_heatmap(image, heatmaps, filename):
    """Blend each heatmap channel over the (resized) image and write it to ``filename.format(k)``."""
    canvas = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()
    heat_h, heat_w = heatmaps.shape[1], heatmaps.shape[2]
    resized = cv2.resize(canvas, (int(heat_w), int(heat_h)))
    # scale heatmap activations to 8-bit for colormapping
    maps = heatmaps.mul(255).clamp(0, 255).byte().cpu().numpy()
    for k in range(maps.shape[0]):
        colored = cv2.applyColorMap(maps[k], cv2.COLORMAP_JET)
        blended = colored * 0.7 + resized * 0.3
        cv2.imwrite(filename.format(k), blended)
def area(left, upper, right, lower):
    """Pixel area of an inclusive box; 0 when the box is empty or inverted."""
    width = max(right - left + 1, 0)
    height = max(lower - upper + 1, 0)
    return width * height
def intersection(box_a, box_b):
    """Intersection of two (left, upper, right, lower) boxes.

    The result may be inverted (right < left) when the boxes are disjoint;
    combine with :func:`area`, which treats inverted boxes as empty.
    """
    lefts, uppers, rights, lowers = zip(box_a, box_b)
    return max(lefts), max(uppers), min(rights), min(lowers)
================================================
FILE: tllib/vision/datasets/object_detection/__init__.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import xml.etree.ElementTree as ET
from detectron2.data import (
MetadataCatalog,
DatasetCatalog,
)
from detectron2.utils.file_io import PathManager
from detectron2.structures import BoxMode
from tllib.vision.datasets._util import download as download_dataset
def parse_root_and_file_name(path):
    """Split a '/'-separated path into ``(parent_dir, file_name)``.

    An empty parent (bare file name, or a leading-slash single component)
    becomes ``'.'``.
    """
    parts = path.split('/')
    file_name = parts[-1]
    dataset_root = '/'.join(parts[:-1]) or '.'
    return dataset_root, file_name
class VOCBase:
    """Pascal-VOC-layout detection dataset with the full 20-class label set.

    Registers itself once in detectron2's Dataset/Metadata catalogs under a
    name derived from ``(root, split)`` and optionally downloads the archive
    (subclasses provide ``archive_name`` and ``dataset_url``).
    """
    class_names = (
        "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
        "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
        "pottedplant", "sheep", "sofa", "train", "tvmonitor"
    )

    def __init__(self, root, split="trainval", year=2007, ext='.jpg', download=True):
        # catalog key must be unique per (root, split) and contain no path separators
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"
        if download:
            dataset_root, file_name = parse_root_and_file_name(root)
            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)
class VOC2007(VOCBase):
    """Pascal VOC 2007 ``trainval`` split (full 20-class label set)."""
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007, self).__init__(root)


class VOC2012(VOCBase):
    """Pascal VOC 2012 ``trainval`` split (full 20-class label set)."""
    archive_name = 'VOC2012.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'

    def __init__(self, root):
        super(VOC2012, self).__init__(root, year=2012)


class VOC2007Test(VOCBase):
    """Pascal VOC 2007 ``test`` split (same archive as VOC2007)."""
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007Test, self).__init__(root, year=2007, split='test')


class Clipart(VOCBase):
    """Clipart dataset in VOC layout; uses VOCBase defaults (split ``trainval``)."""
    archive_name = 'clipart.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/c853a66786e2416a8f18/?dl=1'
class VOCPartialBase:
    """VOC-layout detection dataset restricted to the six classes shared with
    the artistic-domain datasets below (clipart/watercolor/comic).

    Same registration/download behavior as VOCBase, only with the reduced
    ``class_names`` tuple.
    """
    class_names = (
        "bicycle", "bird", "car", "cat", "dog", "person",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.jpg', download=True):
        # catalog key must be unique per (root, split) and contain no path separators
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"
        if download:
            dataset_root, file_name = parse_root_and_file_name(root)
            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)
class VOC2007Partial(VOCPartialBase):
    """Pascal VOC 2007 ``trainval`` split restricted to the 6 shared classes."""
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007Partial, self).__init__(root)


class VOC2012Partial(VOCPartialBase):
    """Pascal VOC 2012 ``trainval`` split restricted to the 6 shared classes."""
    archive_name = 'VOC2012.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'

    def __init__(self, root):
        super(VOC2012Partial, self).__init__(root, year=2012)


class VOC2007PartialTest(VOCPartialBase):
    """Pascal VOC 2007 ``test`` split restricted to the 6 shared classes."""
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007PartialTest, self).__init__(root, year=2007, split='test')


class WaterColor(VOCPartialBase):
    """Watercolor ``train`` split (6 shared classes)."""
    archive_name = 'watercolor.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'

    def __init__(self, root):
        super(WaterColor, self).__init__(root, split='train')


class WaterColorTest(VOCPartialBase):
    """Watercolor ``test`` split (6 shared classes)."""
    archive_name = 'watercolor.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'

    def __init__(self, root):
        super(WaterColorTest, self).__init__(root, split='test')


class Comic(VOCPartialBase):
    """Comic ``train`` split (6 shared classes)."""
    archive_name = 'comic.tar'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'

    def __init__(self, root):
        super(Comic, self).__init__(root, split='train')


class ComicTest(VOCPartialBase):
    """Comic ``test`` split (6 shared classes)."""
    archive_name = 'comic.tar'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'

    def __init__(self, root):
        super(ComicTest, self).__init__(root, split='test')
class CityscapesBase:
    """Cityscapes-style detection dataset in VOC layout (8 classes, .png images).

    Unlike VOCBase, no download is performed and the annotation boxes are
    registered as already 0-based.
    """
    class_names = (
        "bicycle", "bus", "car", "motorcycle", "person", "rider", "train", "truck",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.png'):
        # catalog key must be unique per (root, split) and contain no path separators
        self.name = "{}_{}".format(root, split)
        self.name = self.name.replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext,
                                bbox_zero_based=True)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"


class Cityscapes(CityscapesBase):
    """Cityscapes ``trainval`` split."""
    def __init__(self, root):
        super(Cityscapes, self).__init__(root, split="trainval")


class CityscapesTest(CityscapesBase):
    """Cityscapes ``test`` split."""
    def __init__(self, root):
        super(CityscapesTest, self).__init__(root, split='test')


class FoggyCityscapes(Cityscapes):
    # Same layout and classes as Cityscapes; only the image content differs.
    pass


class FoggyCityscapesTest(CityscapesTest):
    # Same layout and classes as CityscapesTest; only the image content differs.
    pass
class CityscapesCarBase:
    """VOC-layout detection dataset with a single ``car`` category."""
    class_names = (
        "car",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.png', bbox_zero_based=True):
        # catalog key must be unique per (root, split) and contain no path separators
        self.name = "{}_{}".format(root, split)
        self.name = self.name.replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext,
                                bbox_zero_based=bbox_zero_based)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"


class CityscapesCar(CityscapesCarBase):
    # Uses the base defaults: 'trainval' split, .png images, 0-based boxes.
    pass


class CityscapesCarTest(CityscapesCarBase):
    """Car-only ``test`` split."""
    def __init__(self, root):
        super(CityscapesCarTest, self).__init__(root, split='test')


class Sim10kCar(CityscapesCarBase):
    """Sim10k car-only dataset; .jpg images with 1-based VOC-style boxes."""
    def __init__(self, root):
        super(Sim10kCar, self).__init__(root, split='trainval10k', ext='.jpg', bbox_zero_based=False)


class KITTICar(CityscapesCarBase):
    """KITTI car-only dataset; .jpg images with 1-based VOC-style boxes."""
    def __init__(self, root):
        super(KITTICar, self).__init__(root, split='trainval', ext='.jpg', bbox_zero_based=False)


class GTA5(CityscapesBase):
    """GTA5 synthetic dataset with the 8 Cityscapes classes; .jpg images."""
    def __init__(self, root):
        super(GTA5, self).__init__(root, split="trainval", ext='.jpg')
def load_voc_instances(dirname: str, split: str, class_names, ext='.jpg', bbox_zero_based=False):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: Contain "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
        class_names: list or tuple of class names
        ext (str): image file extension. Default: '.jpg'
        bbox_zero_based (bool): whether the XML box coordinates are already 0-based.
            If False, 1 is subtracted from xmin/ymin (VOC annotations are 1-based).

    Returns:
        list[dict]: one dict per image in Detectron2's standard dataset format.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # fix: ``np.str`` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin ``str`` is the documented replacement.
        fileids = np.loadtxt(f, dtype=str)

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    skip_classes = set()
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ext)

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []
        for obj in tree.findall("object"):
            cls = obj.find("name").text
            # objects outside ``class_names`` are skipped (reported once at the end)
            if cls not in class_names:
                skip_classes.add(cls)
                continue
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            #     continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H]
            # Assuming they mean 1-based pixel indices (inclusive),
            # a box with annotation (xmin=1, xmax=W) covers the whole image.
            # In coordinate space this is represented by (xmin=0, xmax=W)
            if bbox_zero_based is False:
                bbox[0] -= 1.0
                bbox[1] -= 1.0
            instances.append(
                {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    print("Skip classes:", list(skip_classes))
    return dicts
def register_pascal_voc(name, dirname, split, year, class_names, **kwargs):
    """Register a VOC-layout dataset under ``name`` in detectron2's catalogs.

    The loader is registered lazily: ``load_voc_instances`` only runs when
    detectron2 first requests the dataset.
    """
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names, **kwargs))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )
================================================
FILE: tllib/vision/datasets/office31.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class Office31(ImageList):
    """Office31 Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \
            ``'D'``: dslr and ``'W'``: webcam.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            amazon/
                images/
                    backpack/
                        *.jpg
                        ...
            dslr/
            webcam/
            image_list/
                amazon.txt
                dslr.txt
                webcam.txt
    """
    # (directory to check, archive name, download url)
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/2c1dd9fbcaa9455aa4ad/?dl=1"),
        ("amazon", "amazon.tgz", "https://cloud.tsinghua.edu.cn/f/ec12dfcddade43ab8101/?dl=1"),
        ("dslr", "dslr.tgz", "https://cloud.tsinghua.edu.cn/f/a41d818ae2f34da7bb32/?dl=1"),
        ("webcam", "webcam.tgz", "https://cloud.tsinghua.edu.cn/f/8a41009a166e4131adcd/?dl=1"),
    ]
    image_list = {
        "A": "image_list/amazon.txt",
        "D": "image_list/dslr.txt",
        "W": "image_list/webcam.txt"
    }
    CLASSES = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair', 'desk_lamp',
               'desktop_computer', 'file_cabinet', 'headphones', 'keyboard', 'laptop_computer', 'letter_tray',
               'mobile_phone', 'monitor', 'mouse', 'mug', 'paper_notebook', 'pen', 'phone', 'printer', 'projector',
               'punchers', 'ring_binder', 'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # fix: the previous ``map`` call passed a two-argument lambda over a
            # single iterable, which raised TypeError whenever download=False.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        super(Office31, self).__init__(root, Office31.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        """Return the list of available domain identifiers."""
        return list(cls.image_list.keys())
================================================
FILE: tllib/vision/datasets/officecaltech.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS, default_loader
from torchvision.datasets.utils import download_and_extract_archive
from ._util import check_exits
class OfficeCaltech(DatasetFolder):
    """Office+Caltech Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \
            ``'D'``: dslr, ``'W'``: webcam and ``'C'``: caltech.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            amazon/
                images/
                    backpack/
                        *.jpg
                        ...
            dslr/
            webcam/
            caltech/
            image_list/
                amazon.txt
                dslr.txt
                webcam.txt
                caltech.txt
    """
    directories = {
        "A": "amazon",
        "D": "dslr",
        "W": "webcam",
        "C": "caltech"
    }
    CLASSES = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard',
               'laptop_computer', 'monitor', 'mouse', 'mug', 'projector']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        if download:
            # all four domains ship in one archive: download it once if any
            # domain directory is missing
            for directory in self.directories.values():
                if not os.path.exists(os.path.join(root, directory)):
                    download_and_extract_archive(url="https://cloud.tsinghua.edu.cn/f/eea518fa781a41d1b20e/?dl=1",
                                                 download_root=os.path.join(root, 'download'),
                                                 filename="office-caltech.tgz", remove_finished=False,
                                                 extract_root=root)
                    break
        else:
            # fix: the previous ``map`` call passed a two-argument lambda over a
            # single iterable, which raised TypeError whenever download=False.
            for directory in self.directories.values():
                check_exits(root, directory)

        super(OfficeCaltech, self).__init__(
            os.path.join(root, self.directories[task]), default_loader, extensions=IMG_EXTENSIONS, **kwargs)
        self.classes = OfficeCaltech.CLASSES
        # fix: the previous dict comprehension iterated the *characters* of each
        # class name; map each class name to its index instead.
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    @property
    def num_classes(self):
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        """Return the list of available domain identifiers."""
        return list(cls.directories.keys())
================================================
FILE: tllib/vision/datasets/officehome.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class OfficeHome(ImageList):
    """`OfficeHome <https://www.hemanthdv.org/officeHomeDataset.html>`_ Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art, \
            ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            Art/
                Alarm_Clock/*.jpg
                ...
            Clipart/
            Product/
            Real_World/
            image_list/
                Art.txt
                Clipart.txt
                Product.txt
                Real_World.txt
    """
    # (directory to check, archive name, download url)
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1"),
        ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/6a006656b9a14567ade2/?dl=1"),
        ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/ae88aa31d2d7411dad79/?dl=1"),
        ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/f219b0ff35e142b3ab48/?dl=1"),
        ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/6c19f3f15bb24ed3951a/?dl=1")
    ]
    image_list = {
        "Ar": "image_list/Art.txt",
        "Cl": "image_list/Clipart.txt",
        "Pr": "image_list/Product.txt",
        "Rw": "image_list/Real_World.txt",
    }
    CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf', 'Toys', 'Sink',
               'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil', 'Bed', 'Hammer', 'ToothBrush', 'Couch',
               'Bike', 'Postit_Notes', 'Mug', 'Webcam', 'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor',
               'Mop', 'Sneakers', 'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries', 'Radio',
               'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker', 'Eraser', 'Bucket', 'Chair',
               'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade', 'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV',
               'Curtains', 'Fork', 'Soda', 'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # fix: the previous ``map`` call passed a two-argument lambda over a
            # single iterable, which raised TypeError whenever download=False.
            for file_name, _, _ in self.download_list:
                check_exits(root, file_name)

        super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        """Return the list of available domain identifiers."""
        return list(cls.image_list.keys())
================================================
FILE: tllib/vision/datasets/openset/__init__.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from typing import Optional, ClassVar, Sequence
from copy import deepcopy
__all__ = ['Office31', 'OfficeHome', "VisDA2017"]
def open_set(dataset_class: ClassVar, public_classes: Sequence[str],
             private_classes: Optional[Sequence[str]] = ()) -> ClassVar:
    """
    Convert a dataset into its open-set version.

    Samples whose class is in `public_classes` keep their (renumbered) label,
    samples whose class is in `private_classes` are relabeled as "unknown",
    and samples belonging to neither sequence are dropped entirely.
    Be aware that `open_set` will change the label number of each category.

    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be open-set.
        public_classes (sequence[str]): A sequence of which categories need to be kept in the open-set dataset.\
            Each element of `public_classes` must belong to the `classes` list of `dataset_class`.
        private_classes (sequence[str], optional): A sequence of which categories need to be marked as "unknown" \
            in the open-set dataset. Each element of `private_classes` must belong to the `classes` list of \
            `dataset_class`. Default: ().

    Examples::

        >>> public_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> private_classes = ['laptop_computer', 'monitor', 'mouse', 'mug', 'projector']
        >>> # create an open-set dataset class which has classes
        >>> # 'back_pack', 'bike', 'calculator', 'headphones', 'keyboard' and 'unknown'.
        >>> OpenSetOffice31 = open_set(Office31, public_classes, private_classes)
        >>> # create an instance of the open-set dataset
        >>> dataset = OpenSetOffice31(root="data/office31", task="A")
    """
    if not (issubclass(dataset_class, ImageList)):
        raise Exception("Only subclass of ImageList can be openset")

    class OpenSetDataset(dataset_class):
        # Relabels samples after the parent class has loaded them:
        # public classes keep their order, "unknown" is always the last label.
        def __init__(self, **kwargs):
            super(OpenSetDataset, self).__init__(**kwargs)
            samples = []
            all_classes = list(deepcopy(public_classes)) + ["unknown"]
            for (path, label) in self.samples:
                class_name = self.classes[label]
                if class_name in public_classes:
                    samples.append((path, all_classes.index(class_name)))
                elif class_name in private_classes:
                    samples.append((path, all_classes.index("unknown")))
            self.samples = samples
            self.classes = all_classes
            self.class_to_idx = {cls: idx
                                 for idx, cls in enumerate(self.classes)}

    return OpenSetDataset
def default_open_set(dataset_class: ClassVar, source: bool) -> ClassVar:
    """
    Default open-set used in some paper.

    Args:
        dataset_class (class): Dataset class. Currently, dataset_class must be one of
            :class:`~tllib.vision.datasets.office31.Office31`, :class:`~tllib.vision.datasets.officehome.OfficeHome`,
            :class:`~tllib.vision.datasets.visda2017.VisDA2017`,
        source (bool): Whether the dataset is used for source domain or not.
            The source domain keeps no private ("unknown") classes.
    """
    if dataset_class == Office31:
        public_classes = Office31.CLASSES[:20]
        private_classes = () if source else Office31.CLASSES[20:]
    elif dataset_class == OfficeHome:
        # OfficeHome classes are split alphabetically
        ordered = sorted(OfficeHome.CLASSES)
        public_classes = ordered[:25]
        private_classes = () if source else ordered[25:]
    elif dataset_class == VisDA2017:
        public_classes = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck')
        private_classes = () if source else ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard')
    else:
        raise NotImplementedError("Unknown openset domain adaptation dataset: {}".format(dataset_class.__name__))
    return open_set(dataset_class, public_classes, private_classes)
================================================
FILE: tllib/vision/datasets/oxfordflowers.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class OxfordFlowers102(ImageList):
    """
    `The Oxford Flowers 102 <https://www.robots.ox.ac.uk/~vgg/data/flowers/102/>`_ is a \
    consistent of 102 flower categories commonly occurring in the United Kingdom. \
    Each class consists of between 40 and 258 images. The images have large scale, \
    pose and light variations. In addition, there are categories that have large \
    variations within the category and several very similar categories. \
    The dataset is divided into a training set, a validation set and a test set. \
    The training set and validation set each consist of 10 images per class \
    (totalling 1020 images each). \
    The test set consists of the remaining 6149 images (minimum 20 per class).
    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # (directory to extract, archive name, download url)
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/161c7b222d6745408201/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/59b6a3fa3dac4404aa3b/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/ec77da479dfb471982fb/?dl=1")
    ]
    CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',
               'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',
               "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',
               'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower',
               'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',
               'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist',
               'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort',
               'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia',
               'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion',
               'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura',
               'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone',
               'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus',
               'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
               'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis',
               'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress',
               'canna lily', 'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia',
               'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']

    def __init__(self, root: str, split: str = 'train', download: bool = False, **kwargs):
        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Bug fix: the previous ``map`` passed each 3-tuple as a single argument
            # to a two-parameter lambda, raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)
        super(OxfordFlowers102, self).__init__(root, OxfordFlowers102.CLASSES,
                                               os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs)
================================================
FILE: tllib/vision/datasets/oxfordpets.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class OxfordIIITPets(ImageList):
    """`The Oxford-IIIT Pets <https://www.robots.ox.ac.uk/~vgg/data/pets/>`_ \
    is a 37-category pet dataset with roughly 200 images for each class.
    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    # (directory to extract, archive name, download url)
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/8295cfba35b148529bc3/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/89e422c95cb54fb7b0cc/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/dbf7ac10e25b4262b8e5/?dl=1"),
    ]
    # Maps a split name (optionally with a sampling rate suffix) to its image list file.
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    CLASSES = ['Abyssinian', 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'Bengal',
               'Birman', 'Bombay', 'boxer', 'British_Shorthair', 'chihuahua', 'Egyptian_Mau', 'english_cocker_spaniel',
               'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger',
               'Maine_Coon', 'miniature_pinscher', 'newfoundland', 'Persian', 'pomeranian', 'pug', 'Ragdoll',
               'Russian_Blue', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'Siamese', 'Sphynx',
               'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier']

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
                 **kwargs):
        if split == 'train':
            # The sampling rate only applies to the training split.
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])
        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Bug fix: the previous ``map`` passed each 3-tuple as a single argument
            # to a two-parameter lambda, raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)
        super(OxfordIIITPets, self).__init__(root, OxfordIIITPets.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/pacs.py
================================================
from typing import Optional
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class PACS(ImageList):
    """`PACS Dataset <https://domaingeneralization.github.io/#data>`_.
    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: art_painting, \
            ``'C'``: cartoon, ``'P'``: photo and ``'S'``: sketch.
        split (str, optional): The dataset split, supports ``train``, ``val``, ``all`` or ``test``. \
            ``test`` is treated as an alias of ``all``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    .. note:: In `root`, there will exist following files after downloading.
        ::
            art_painting/
                dog/
                    *.jpg
                ...
            cartoon/
            photo/
            sketch/
            image_list/
                art_painting.txt
                cartoon.txt
                photo.txt
                sketch.txt
    """
    # (directory to extract, archive name, download url)
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/603a1fea81f2415ab7e0/?dl=1"),
        ("art_painting", "art_painting.tgz", "https://cloud.tsinghua.edu.cn/f/46684292e979402b8d87/?dl=1"),
        ("cartoon", "cartoon.tgz", "https://cloud.tsinghua.edu.cn/f/7bfa413b34ec4f4fa384/?dl=1"),
        ("photo", "photo.tgz", "https://cloud.tsinghua.edu.cn/f/45f71386a668475d8b42/?dl=1"),
        ("sketch", "sketch.tgz", "https://cloud.tsinghua.edu.cn/f/4ba559535e4b4b6981e5/?dl=1"),
    ]
    # Image-list templates; '{}' is filled with the split name (train/val/all).
    image_list = {
        "A": "image_list/art_painting_{}.txt",
        "C": "image_list/cartoon_{}.txt",
        "P": "image_list/photo_{}.txt",
        "S": "image_list/sketch_{}.txt"
    }
    CLASSES = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']

    def __init__(self, root: str, task: str, split='all', download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        assert split in ["train", "val", "all", "test"]
        if split == "test":
            # PACS has no separate test list; evaluation uses the full domain.
            split = "all"
        data_list_file = os.path.join(root, self.image_list[task].format(split))
        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Bug fix: the previous ``map`` passed each 3-tuple as a single argument
            # to a two-parameter lambda, raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)
        # PACS image lists use 1-based labels; shift them to 0-based.
        super(PACS, self).__init__(root, PACS.CLASSES, data_list_file=data_list_file, target_transform=lambda x: x - 1,
                                   **kwargs)

    @classmethod
    def domains(cls):
        return list(cls.image_list.keys())
================================================
FILE: tllib/vision/datasets/partial/__init__.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from ..officecaltech import OfficeCaltech
from .imagenet_caltech import ImageNetCaltech
from .caltech_imagenet import CaltechImageNet
from tllib.vision.datasets.partial.imagenet_caltech import ImageNetCaltech
from typing import Sequence, ClassVar
# Public API for star-imports.
# NOTE(review): the factory functions `partial` and `default_partial` defined
# below have docstrings and look public, but are not listed here, so
# `from ... import *` will not expose them — confirm whether that is intentional.
__all__ = ['Office31', 'OfficeHome', "VisDA2017", "CaltechImageNet", "ImageNetCaltech"]
def partial(dataset_class: ClassVar, partial_classes: Sequence[str]) -> ClassVar:
    """
    Convert a dataset into its partial version.
    In other words, those samples which doesn't belong to `partial_classes` will be discarded.
    Yet `partial` will not change the label space of `dataset_class`.
    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be partial.
        partial_classes (sequence[str]): A sequence of which categories need to be kept in the partial dataset.\
            Each element of `partial_classes` must belong to the `classes` list of `dataset_class`.
    Examples::
        >>> partial_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> # create a partial dataset class
        >>> PartialOffice31 = partial(Office31, partial_classes)
        >>> # create an instance of the partial dataset
        >>> dataset = PartialOffice31(root="data/office31", task="A")
    """
    if not issubclass(dataset_class, ImageList):
        raise Exception("Only subclass of ImageList can be partial")

    class PartialDataset(dataset_class):
        def __init__(self, **kwargs):
            super(PartialDataset, self).__init__(**kwargs)
            # Every requested category must exist in the wrapped dataset.
            assert all(c in self.classes for c in partial_classes)
            # Keep only samples of the requested categories; labels are unchanged.
            self.samples = [(path, label) for path, label in self.samples
                            if self.classes[label] in partial_classes]
            self.partial_classes = partial_classes
            self.partial_classes_idx = [self.class_to_idx[c] for c in partial_classes]

    return PartialDataset
def default_partial(dataset_class: ClassVar) -> ClassVar:
    """
    Default partial used in some paper.
    Args:
        dataset_class (class): Dataset class. Currently, dataset_class must be one of
            :class:`~tllib.vision.datasets.office31.Office31`, :class:`~tllib.vision.datasets.officehome.OfficeHome`,
            :class:`~tllib.vision.datasets.visda2017.VisDA2017`,
            :class:`~tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech`
            and :class:`~tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet`.
    """
    # Pick the benchmark-specific subset of categories to keep.
    if dataset_class in (ImageNetCaltech, CaltechImageNet):
        # These two datasets already restrict CLASSES to the shared label set.
        kept_classes = dataset_class.CLASSES
    elif dataset_class == Office31:
        kept_classes = OfficeCaltech.CLASSES
    elif dataset_class == OfficeHome:
        kept_classes = sorted(OfficeHome.CLASSES)[:25]
    elif dataset_class == VisDA2017:
        kept_classes = sorted(VisDA2017.CLASSES)[:6]
    else:
        raise NotImplementedError("Unknown partial domain adaptation dataset: {}".format(dataset_class.__name__))
    return partial(dataset_class, kept_classes)
================================================
FILE: tllib/vision/datasets/partial/caltech_imagenet.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os
from ..imagelist import ImageList
from .._util import download as download_data, check_exits
# Category names shared by the CaltechImageNet / CaltechImageNetUniversal
# datasets below (Caltech-256 vocabulary; names ending in " 101" come from the
# overlap with Caltech-101). Presumably the position of each name defines its
# integer label via ImageList — confirm against the ImageList implementation.
_CLASSES = ['ak47', 'american flag', 'backpack', 'baseball bat', 'baseball glove', 'basketball hoop', 'bat',
            'bathtub', 'bear', 'beer mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai 101',
            'boom box', 'bowling ball', 'bowling pin', 'boxing glove', 'brain 101', 'breadmaker', 'buddha 101',
            'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car tire',
            'cartman', 'cd', 'centipede', 'cereal box', 'chandelier 101', 'chess board', 'chimp', 'chopsticks',
            'cockroach', 'coffee mug', 'coffin', 'coin', 'comet', 'computer keyboard', 'computer monitor',
            'computer mouse', 'conch', 'cormorant', 'covered wagon', 'cowboy hat', 'crab 101', 'desk globe',
            'diamond ring', 'dice', 'dog', 'dolphin 101', 'doorknob', 'drinking straw', 'duck', 'dumb bell',
            'eiffel tower', 'electric guitar 101', 'elephant 101', 'elk', 'ewer 101', 'eyeglasses', 'fern',
            'fighter jet', 'fire extinguisher', 'fire hydrant', 'fire truck', 'fireworks', 'flashlight',
            'floppy disk', 'football helmet', 'french horn', 'fried egg', 'frisbee', 'frog', 'frying pan',
            'galaxy', 'gas pump', 'giraffe', 'goat', 'golden gate bridge', 'goldfish', 'golf ball', 'goose',
            'gorilla', 'grand piano 101', 'grapes', 'grasshopper', 'guitar pick', 'hamburger', 'hammock',
            'harmonica', 'harp', 'harpsichord', 'hawksbill 101', 'head phones', 'helicopter 101', 'hibiscus',
            'homer simpson', 'horse', 'horseshoe crab', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass',
            'house fly', 'human skeleton', 'hummingbird', 'ibis 101', 'ice cream cone', 'iguana', 'ipod',
            'iris', 'jesus christ', 'joy stick', 'kangaroo 101', 'kayak', 'ketch 101', 'killer whale', 'knife',
            'ladder', 'laptop 101', 'lathe', 'leopards 101', 'license plate', 'lightbulb', 'light house',
            'lightning', 'llama 101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah 101',
            'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes 101', 'mountain bike', 'mushroom',
            'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm pilot', 'palm tree', 'paperclip',
            'paper shredder', 'pci card', 'penguin', 'people', 'pez dispenser', 'photocopier', 'picnic table',
            'playing card', 'porcupine', 'pram', 'praying mantis', 'pyramid', 'raccoon', 'radio telescope',
            'rainbow', 'refrigerator', 'revolver 101', 'rifle', 'rotary phone', 'roulette wheel', 'saddle',
            'saturn', 'school bus', 'scorpion 101', 'screwdriver', 'segway', 'self propelled lawn mower',
            'sextant', 'sheet music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake',
            'sneaker', 'snowmobile', 'soccer ball', 'socks', 'soda can', 'spaghetti', 'speed boat', 'spider',
            'spoon', 'stained glass', 'starfish 101', 'steering wheel', 'stirrups', 'sunflower 101', 'superman',
            'sushi', 'swan', 'swiss army knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy bear',
            'teepee', 'telephone box', 'tennis ball', 'tennis court', 'tennis racket', 'theodolite', 'toaster',
            'tomato', 'tombstone', 'top hat', 'touring bike', 'tower pisa', 'traffic light', 'treadmill',
            'triceratops', 'tricycle', 'trilobite 101', 'tripod', 't shirt', 'tuning fork', 'tweezer',
            'umbrella 101', 'unicorn', 'vcr', 'video projector', 'washing machine', 'watch 101', 'waterfall',
            'watermelon', 'welding mask', 'wheelbarrow', 'windmill', 'wine bottle', 'xylophone', 'yarmulke',
            'yo yo', 'zebra', 'airplanes 101', 'car side 101', 'faces easy 101', 'greyhound', 'tennis shoes',
            'toad']
class CaltechImageNet(ImageList):
    """Caltech-ImageNet is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ and
    `ImageNet-1K <http://image-net.org/>`_ .
    They share 84 common classes. Caltech-ImageNet keeps all classes of Caltech-256.
    The label is based on the Caltech256 (class 0-255) . The private classes of ImageNet-1K is discarded.
    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \
            ``'I'``: ImageNet-1K validation set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and ImageList automatically.
        In `root`, there will exist following files after downloading.
        ::
            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    image_list = {
        "C": "image_list/caltech_256_list.txt",
        "I": "image_list/imagenet_val_84_list.txt",
    }
    CLASSES = _CLASSES

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])
        # ``download_list`` is a module-level manifest defined below the classes;
        # it is resolved at call time, after the module is fully loaded.
        if download:
            for args in download_list:
                download_data(root, *args)
        else:
            # Bug fix: the previous ``map`` passed each tuple as a single argument
            # to a two-parameter lambda, raising TypeError whenever download=False.
            for file_name, *_ in download_list:
                check_exits(root, file_name)
        if not os.path.exists(os.path.join(root, 'val')):
            # NOTE: aborts the whole process; ImageNet-1K must be placed manually.
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)
        super(CaltechImageNet, self).__init__(root, CaltechImageNet.CLASSES, data_list_file=data_list_file, **kwargs)
class CaltechImageNetUniversal(ImageList):
    """Caltech-ImageNet-Universal is constructed from `Caltech-256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_
    and `ImageNet-1K <http://image-net.org/>`_ .
    They share 84 common classes. Caltech-ImageNet keeps all classes of Caltech-256.
    The label is based on the Caltech256 (class 0-255) . The private classes of ImageNet-1K is grouped into class 256 ("unknown").
    Thus, CaltechImageNetUniversal has 257 classes in total.
    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \
            ``'I'``: ImageNet-1K validation set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and ImageList automatically.
        In `root`, there will exist following files after downloading.
        ::
            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    image_list = {
        "C": "image_list/caltech_256_list.txt",
        "I": "image_list/imagenet_val_85_list.txt",
    }
    # Extra trailing class 256 collects all ImageNet-1K private classes.
    CLASSES = _CLASSES + ['unknown']

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])
        # ``download_list`` is a module-level manifest defined below the classes;
        # it is resolved at call time, after the module is fully loaded.
        if download:
            for args in download_list:
                download_data(root, *args)
        else:
            # Bug fix: the previous ``map`` passed each tuple as a single argument
            # to a two-parameter lambda, raising TypeError whenever download=False.
            for file_name, *_ in download_list:
                check_exits(root, file_name)
        if not os.path.exists(os.path.join(root, 'val')):
            # NOTE: aborts the whole process; ImageNet-1K must be placed manually.
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)
        super(CaltechImageNetUniversal, self).__init__(root, CaltechImageNetUniversal.CLASSES,
                                                       data_list_file=data_list_file, **kwargs)
# Download manifest shared by CaltechImageNet and CaltechImageNetUniversal above.
# Each entry is (extracted directory name, archive file name, download url).
# It is deliberately module-level (not a class attribute) and defined after the
# classes; that works because the classes only read it inside __init__, i.e.
# after the module has finished importing.
download_list = [
    ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1"),
    ("256_ObjectCategories", "256_ObjectCategories.tar",
     "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar"),
]
================================================
FILE: tllib/vision/datasets/partial/imagenet_caltech.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from ..imagelist import ImageList
from .._util import download as download_data, check_exits
_CLASSES = [c[0] for c in [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
('great white shark', 'white shark', 'man-eater', 'man-eating shark', 'Carcharodon carcharias'),
('tiger shark', 'Galeocerdo cuvieri'), ('hammerhead', 'hammerhead shark'),
('electric ray', 'crampfish', 'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
('ostrich', 'Struthio camelus'), ('brambling', 'Fringilla montifringilla'),
('goldfinch', 'Carduelis carduelis'), ('house finch', 'linnet', 'Carpodacus mexicanus'),
('junco', 'snowbird'), ('indigo bunting', 'indigo finch', 'indigo bird', 'Passerina cyanea'),
('robin', 'American robin', 'Turdus migratorius'), ('bulbul',), ('jay',), ('magpie',),
('chickadee',), ('water ouzel', 'dipper'), ('kite',),
('bald eagle', 'American eagle', 'Haliaeetus leucocephalus'), ('vulture',),
('great grey owl', 'great gray owl', 'Strix nebulosa'),
('European fire salamander', 'Salamandra salamandra'), ('common newt', 'Triturus vulgaris'),
('eft',), ('spotted salamander', 'Ambystoma maculatum'),
('axolotl', 'mud puppy', 'Ambystoma mexicanum'), ('bullfrog', 'Rana catesbeiana'),
('tree frog', 'tree-frog'),
('tailed frog', 'bell toad', 'ribbed toad', 'tailed toad', 'Ascaphus trui'),
('loggerhead', 'loggerhead turtle', 'Caretta caretta'),
('leatherback turtle', 'leatherback', 'leathery turtle', 'Dermochelys coriacea'),
('mud turtle',), ('terrapin',), ('box turtle', 'box tortoise'), ('banded gecko',),
('common iguana', 'iguana', 'Iguana iguana'),
('American chameleon', 'anole', 'Anolis carolinensis'), ('whiptail', 'whiptail lizard'),
('agama',), ('frilled lizard', 'Chlamydosaurus kingi'), ('alligator lizard',),
('Gila monster', 'Heloderma suspectum'), ('green lizard', 'Lacerta viridis'),
('African chameleon', 'Chamaeleo chamaeleon'),
('Komodo dragon', 'Komodo lizard', 'dragon lizard', 'giant lizard', 'Varanus komodoensis'),
('African crocodile', 'Nile crocodile', 'Crocodylus niloticus'),
('American alligator', 'Alligator mississipiensis'), ('triceratops',),
('thunder snake', 'worm snake', 'Carphophis amoenus'),
('ringneck snake', 'ring-necked snake', 'ring snake'),
('hognose snake', 'puff adder', 'sand viper'), ('green snake', 'grass snake'),
('king snake', 'kingsnake'), ('garter snake', 'grass snake'), ('water snake',),
('vine snake',), ('night snake', 'Hypsiglena torquata'),
('boa constrictor', 'Constrictor constrictor'), ('rock python', 'rock snake', 'Python sebae'),
('Indian cobra', 'Naja naja'), ('green mamba',), ('sea snake',),
('horned viper', 'cerastes', 'sand viper', 'horned asp', 'Cerastes cornutus'),
('diamondback', 'diamondback rattlesnake', 'Crotalus adamanteus'),
('sidewinder', 'horned rattlesnake', 'Crotalus cerastes'), ('trilobite',),
('harvestman', 'daddy longlegs', 'Phalangium opilio'), ('scorpion',),
('black and gold garden spider', 'Argiope aurantia'), ('barn spider', 'Araneus cavaticus'),
('garden spider', 'Aranea diademata'), ('black widow', 'Latrodectus mactans'), ('tarantula',),
('wolf spider', 'hunting spider'), ('tick',), ('centipede',), ('black grouse',),
('ptarmigan',), ('ruffed grouse', 'partridge', 'Bonasa umbellus'),
('prairie chicken', 'prairie grouse', 'prairie fowl'), ('peacock',), ('quail',),
('partridge',), ('African grey', 'African gray', 'Psittacus erithacus'), ('macaw',),
('sulphur-crested cockatoo', 'Kakatoe galerita', 'Cacatua galerita'), ('lorikeet',),
('coucal',), ('bee eater',), ('hornbill',), ('hummingbird',), ('jacamar',), ('toucan',),
('drake',), ('red-breasted merganser', 'Mergus serrator'), ('goose',),
('black swan', 'Cygnus atratus'), ('tusker',), ('echidna', 'spiny anteater', 'anteater'), (
'platypus', 'duckbill', 'duckbilled platypus', 'duck-billed platypus',
'Ornithorhynchus anatinus'), ('wallaby', 'brush kangaroo'),
('koala', 'koalabear', 'kangaroo bear', 'native bear', 'Phascolarctos cinereus'), ('wombat',),
('jellyfish',), ('sea anemone', 'anemone'), ('brain coral',), ('flatworm', 'platyhelminth'),
('nematode', 'nematode worm', 'roundworm'), ('conch',), ('snail',), ('slug',),
('sea slug', 'nudibranch'), ('chiton', 'coat-of-mail shell', 'sea cradle', 'polyplacophore'),
('chambered nautilus', 'pearly nautilus', 'nautilus'), ('Dungeness crab', 'Cancer magister'),
('rock crab', 'Cancer irroratus'), ('fiddler crab',), (
'king crab', 'Alaska crab', 'Alaskan king crab', 'Alaska king crab',
'Paralithodes camtschatica'),
('American lobster', 'Northernlobster', 'Maine lobster', 'Homarus americanus'),
('spiny lobster', 'langouste', 'rock lobster', 'crawfish', 'crayfish', 'sea crawfish'),
('crayfish', 'crawfish', 'crawdad', 'crawdaddy'), ('hermit crab',), ('isopod',),
('white stork', 'Ciconia ciconia'), ('black stork', 'Ciconia nigra'), ('spoonbill',),
('flamingo',), ('little blue heron', 'Egretta caerulea'),
('American egret', 'great white heron', 'Egretta albus'), ('bittern',), ('crane',),
('limpkin', 'Aramus pictus'), ('European gallinule', 'Porphyrio porphyrio'),
('American coot', 'marsh hen', 'mud hen', 'water hen', 'Fulica americana'), ('bustard',),
('ruddy turnstone', 'Arenaria interpres'),
('red-backed sandpiper', 'dunlin', 'Erolia alpina'), ('redshank', 'Tringa totanus'),
('dowitcher',), ('oystercatcher', 'oyster catcher'), ('pelican',),
('king penguin', 'Aptenodytes patagonica'), ('albatross', 'mollymawk'),
('grey whale', 'gray whale', 'devilfish', 'Eschrichtius gibbosus', 'Eschrichtius robustus'),
('killer whale', 'killer', 'orca', 'grampus', 'sea wolf', 'Orcinusorca'),
('dugong', 'Dugong dugon'), ('sea lion',), ('Chihuahua',), ('Japanese spaniel',),
('Maltese dog', 'Maltese terrier', 'Maltese'), ('Pekinese', 'Pekingese', 'Peke'),
('Shih-Tzu',), ('Blenheim spaniel',), ('papillon',), ('toy terrier',),
('Rhodesian ridgeback',), ('Afghan hound', 'Afghan'), ('basset', 'basset hound'), ('beagle',),
('bloodhound', 'sleuthhound'), ('bluetick',), ('black-and-tan coonhound',),
('Walker hound', 'Walker foxhound'), ('English foxhound',), ('redbone',),
('borzoi', 'Russian wolfhound'), ('Irish wolfhound',), ('Italian greyhound',), ('whippet',),
('Ibizan hound', 'Ibizan Podenco'), ('Norwegian elkhound', 'elkhound'),
('otterhound', 'otter hound'), ('Saluki', 'gazelle hound'),
('Scottish deerhound', 'deerhound'), ('Weimaraner',),
('Staffordshire bullterrier', 'Staffordshire bull terrier'), (
'American Staffordshire terrier', 'Staffordshire terrier', 'American pit bull terrier',
'pit bull terrier'), ('Bedlington terrier',), ('Border terrier',),
('Kerry blue terrier',),
('Irish terrier',), ('Norfolkterrier',), ('Norwich terrier',), ('Yorkshire terrier',),
('wire-haired fox terrier',), ('Lakeland terrier',), ('Sealyham terrier', 'Sealyham'),
('Airedale', 'Airedale terrier'), ('cairn', 'cairn terrier'), ('Australian terrier',),
('Dandie Dinmont', 'Dandie Dinmont terrier'), ('Boston bull', 'Boston terrier'),
('miniature schnauzer',), ('giant schnauzer',), ('standard schnauzer',),
('Scotch terrier', 'Scottish terrier', 'Scottie'), ('Tibetan terrier', 'chrysanthemum dog'),
('silky terrier', 'Sydney silky'), ('soft-coated wheaten terrier',),
('West Highland white terrier',), ('Lhasa', 'Lhasa apso'), ('flat-coated retriever',),
('curly-coated retriever',), ('golden retriever',), ('Labrador retriever',),
('Chesapeake Bay retriever',), ('German short-haired pointer',),
('vizsla', 'Hungarian pointer'), ('English setter',), ('Irish setter', 'red setter'),
('Gordon setter',), ('Brittany spaniel',), ('clumber', 'clumber spaniel'),
('English springer', 'English springer spaniel'), ('Welsh springer spaniel',),
('cocker spaniel', 'English cocker spaniel', 'cocker'), ('Sussex spaniel',),
('Irish water spaniel',), ('kuvasz',), ('schipperke',), ('groenendael',), ('malinois',),
('briard',), ('kelpie',), ('komondor',), ('Old English sheepdog', 'bobtail'),
('Shetland sheepdog', 'Shetland sheep dog', 'Shetland'), ('collie',), ('Border collie',),
('Bouvier des Flandres', 'Bouviers des Flandres'), ('Rottweiler',),
('German shepherd', 'German shepherd dog', 'German police dog', 'alsatian'),
('Doberman', 'Doberman pinscher'), ('miniature pinscher',), ('Greater Swiss Mountain dog',),
('Bernese mountain dog',), ('Appenzeller',), ('EntleBucher',), ('boxer',), ('bull mastiff',),
('Tibetan mastiff',), ('French bulldog',), ('Great Dane',), ('Saint Bernard', 'St Bernard'),
('Eskimo dog', 'husky'), ('malamute', 'malemute', 'Alaskan malamute'), ('Siberian husky',),
('dalmatian', 'coach dog', 'carriage dog'),
('affenpinscher', 'monkey pinscher', 'monkey dog'), ('basenji',), ('pug', 'pug-dog'),
('Leonberg',), ('Newfoundland', 'Newfoundland dog'), ('Great Pyrenees',),
('Samoyed', 'Samoyede'), ('Pomeranian',), ('chow', 'chow chow'), ('keeshond',),
('Brabancon griffon',), ('Pembroke', 'Pembroke Welsh corgi'),
('Cardigan', 'Cardigan Welsh corgi'), ('toy poodle',), ('miniature poodle',),
('standard poodle',), ('Mexican hairless',),
('timber wolf', 'grey wolf', 'gray wolf', 'Canis lupus'),
('white wolf', 'Arctic wolf', 'Canis lupus tundrarum'),
('red wolf', 'maned wolf', 'Canis rufus', 'Canis niger'),
('coyote', 'prairie wolf', 'brush wolf', 'Canis latrans'),
('dingo', 'warrigal', 'warragal', 'Canis dingo'), ('dhole', 'Cuon alpinus'),
('African hunting dog', 'hyena dog', 'Cape hunting dog', 'Lycaon pictus'),
('hyena', 'hyaena'), ('red fox', 'Vulpes vulpes'), ('kit fox', 'Vulpes macrotis'),
('Arctic fox', 'white fox', 'Alopex lagopus'),
('grey fox', 'gray fox', 'Urocyon cinereoargenteus'), ('tabby', 'tabby cat'), ('tiger cat',),
('Persian cat',), ('Siamese cat', 'Siamese'), ('Egyptian cat',),
('cougar', 'puma', 'catamount', 'mountain lion', 'painter', 'panther', 'Felis concolor'),
('lynx', 'catamount'), ('leopard', 'Panthera pardus'),
('snow leopard', 'ounce', 'Panthera uncia'),
('jaguar', 'panther', 'Panthera onca', 'Felis onca'),
('lion', 'king of beasts', 'Panthera leo'), ('tiger', 'Panthera tigris'),
('cheetah', 'chetah', 'Acinonyx jubatus'), ('brown bear', 'bruin', 'Ursus arctos'),
('American black bear', 'black bear', 'Ursus americanus', 'Euarctos americanus'),
('ice bear', 'polar bear', 'Ursus Maritimus', 'Thalarctos maritimus'),
('sloth bear', 'Melursus ursinus', 'Ursus ursinus'), ('mongoose',), ('meerkat', 'mierkat'),
('tiger beetle',), ('ladybug', 'ladybeetle', 'lady beetle', 'ladybird', 'ladybird beetle'),
('ground beetle', 'carabid beetle'), ('long-horned beetle', 'longicorn', 'longicorn beetle'),
('leaf beetle', 'chrysomelid'), ('dung beetle',), ('rhinoceros beetle',), ('weevil',),
('fly',), ('bee',), ('ant', 'emmet', 'pismire'), ('grasshopper', 'hopper'), ('cricket',),
('walking stick', 'walkingstick', 'stick insect'), ('cockroach', 'roach'),
('mantis', 'mantid'), ('cicada', 'cicala'), ('leafhopper',), ('lacewing', 'lacewing fly'), (
'dragonfly', 'darning needle', "devil's darning needle", 'sewing needle', 'snake feeder',
'snake doctor', 'mosquito hawk', 'skeeter hawk'), ('damselfly',), ('admiral',),
('ringlet', 'ringlet butterfly'),
('monarch', 'monarch butterfly', 'milkweed butterfly', 'Danaus plexippus'),
('cabbage butterfly',), ('sulphur butterfly', 'sulfur butterfly'),
('lycaenid', 'lycaenid butterfly'), ('starfish', 'sea star'), ('sea urchin',),
('sea cucumber', 'holothurian'), ('wood rabbit', 'cottontail', 'cottontail rabbit'),
('hare',), ('Angora', 'Angora rabbit'), ('hamster',), ('porcupine', 'hedgehog'),
('fox squirrel', 'eastern fox squirrel', 'Sciurus niger'), ('marmot',), ('beaver',),
('guinea pig', 'Cavia cobaya'), ('sorrel',), ('zebra',),
('hog', 'pig', 'grunter', 'squealer', 'Sus scrofa'), ('wild boar', 'boar', 'Sus scrofa'),
('warthog',), ('hippopotamus', 'hippo', 'river horse', 'Hippopotamus amphibius'), ('ox',),
('water buffalo', 'water ox', 'Asiatic buffalo', 'Bubalus bubalis'), ('bison',),
('ram', 'tup'), (
'bighorn', 'bighorn sheep', 'cimarron', 'Rocky Mountain bighorn', 'Rocky Mountain sheep',
'Ovis canadensis'), ('ibex', 'Capra ibex'), ('hartebeest',),
('impala', 'Aepyceros melampus'),
('gazelle',), ('Arabian camel', 'dromedary', 'Camelus dromedarius'), ('llama',), ('weasel',),
('mink',), ('polecat', 'fitch', 'foulmart', 'foumart', 'Mustela putorius'),
('black-footed ferret', 'ferret', 'Mustela nigripes'), ('otter',),
('skunk', 'polecat', 'wood pussy'), ('badger',), ('armadillo',),
('three-toed sloth', 'ai', 'Bradypus tridactylus'),
('orangutan', 'orang', 'orangutang', 'Pongo pygmaeus'), ('gorilla', 'Gorilla gorilla'),
('chimpanzee', 'chimp', 'Pan troglodytes'), ('gibbon', 'Hylobates lar'),
('siamang', 'Hylobates syndactylus', 'Symphalangus syndactylus'), ('guenon', 'guenon monkey'),
('patas', 'hussar monkey', 'Erythrocebus patas'), ('baboon',), ('macaque',), ('langur',),
('colobus', 'colobus monkey'), ('proboscis monkey', 'Nasalis larvatus'), ('marmoset',),
('capuchin', 'ringtail', 'Cebus capucinus'), ('howler monkey', 'howler'),
('titi', 'titi monkey'), ('spider monkey', 'Ateles geoffroyi'),
('squirrel monkey', 'Saimiri sciureus'),
('Madagascar cat', 'ring-tailed lemur', 'Lemur catta'),
('indri', 'indris', 'Indri indri', 'Indri brevicaudatus'),
('Indian elephant', 'Elephas maximus'), ('African elephant', 'Loxodonta africana'),
('lesser panda', 'red panda', 'panda', 'bear cat', 'cat bear', 'Ailurus fulgens'),
('giant panda', 'panda', 'panda bear', 'coon bear', 'Ailuropoda melanoleuca'),
('barracouta', 'snoek'), ('eel',),
('coho', 'cohoe', 'coho salmon', 'blue jack', 'silver salmon', 'Oncorhynchus kisutch'),
('rock beauty', 'Holocanthus tricolor'), ('anemone fish',), ('sturgeon',),
('gar', 'garfish', 'garpike', 'billfish', 'Lepisosteus osseus'), ('lionfish',),
('puffer', 'pufferfish', 'blowfish', 'globefish'), ('abacus',), ('abaya',),
('academic gown', 'academic robe', "judge's robe"),
('accordion', 'piano accordion', 'squeeze box'), ('acoustic guitar',),
('aircraft carrier', 'carrier', 'flattop', 'attack aircraft carrier'), ('airliner',),
('airship', 'dirigible'), ('altar',), ('ambulance',), ('amphibian', 'amphibious vehicle'),
('analog clock',), ('apiary', 'bee house'), ('apron',), (
'ashcan', 'trash can', 'garbage can', 'wastebin', 'ash bin', 'ash-bin', 'ashbin',
'dustbin',
'trash barrel', 'trash bin'), ('assault rifle', 'assault gun'),
('backpack', 'back pack', 'knapsack', 'packsack', 'rucksack', 'haversack'),
('bakery', 'bakeshop', 'bakehouse'), ('balance beam', 'beam'), ('balloon',),
('ballpoint', 'ballpoint pen', 'ballpen', 'Biro'), ('Band Aid',), ('banjo',),
('bannister', 'banister', 'balustrade', 'balusters', 'handrail'), ('barbell',),
('barber chair',), ('barbershop',), ('barn',), ('barometer',), ('barrel', 'cask'),
('barrow', 'garden cart', 'lawn cart', 'wheelbarrow'), ('baseball',), ('basketball',),
('bassinet',), ('bassoon',), ('bathing cap', 'swimming cap'), ('bath towel',),
('bathtub', 'bathing tub', 'bath', 'tub'), (
'beach wagon', 'station wagon', 'wagon', 'estate car', 'beach waggon', 'station waggon',
'waggon'), ('beacon', 'lighthouse', 'beacon light', 'pharos'), ('beaker',),
('bearskin', 'busby', 'shako'), ('beer bottle',), ('beer glass',), ('bell cote', 'bell cot'),
('bib',), ('bicycle-built-for-two', 'tandem bicycle', 'tandem'), ('bikini', 'two-piece'),
('binder', 'ring-binder'), ('binoculars', 'field glasses', 'opera glasses'), ('birdhouse',),
('boathouse',), ('bobsled', 'bobsleigh', 'bob'), ('bolo tie', 'bolo', 'bola tie', 'bola'),
('bonnet', 'poke bonnet'), ('bookcase',), ('bookshop', 'bookstore', 'bookstall'),
('bottlecap',), ('bow',), ('bow tie', 'bow-tie', 'bowtie'),
('brass', 'memorial tablet', 'plaque'), ('brassiere', 'bra', 'bandeau'),
('breakwater', 'groin', 'groyne', 'mole', 'bulwark', 'seawall', 'jetty'),
('breastplate', 'aegis', 'egis'), ('broom',), ('bucket', 'pail'), ('buckle',),
('bulletproof vest',), ('bullet train', 'bullet'), ('butcher shop', 'meat market'),
('cab', 'hack', 'taxi', 'taxicab'), ('caldron', 'cauldron'), ('candle', 'taper', 'wax light'),
('cannon',), ('canoe',), ('can opener', 'tin opener'), ('cardigan',), ('car mirror',),
('carousel', 'carrousel', 'merry-go-round', 'roundabout', 'whirligig'),
("carpenter's kit", 'tool kit'), ('carton',), ('car wheel',), (
'cash machine', 'cash dispenser', 'automated teller machine', 'automatic teller machine',
'automated teller', 'automatic teller', 'ATM'), ('cassette',), ('cassette player',),
('castle',), ('catamaran',), ('CD player',), ('cello', 'violoncello'),
('cellular telephone', 'cellular phone', 'cellphone', 'cell', 'mobile phone'), ('chain',),
('chainlink fence',), (
'chain mail', 'ring mail', 'mail', 'chain armor', 'chain armour', 'ring armor',
'ring armour'), ('chain saw', 'chainsaw'), ('chest',), ('chiffonier', 'commode'),
('chime', 'bell', 'gong'), ('china cabinet', 'china closet'), ('Christmas stocking',),
('church', 'church building'),
('cinema', 'movie theater', 'movie theatre', 'movie house', 'picture palace'),
('cleaver', 'meat cleaver', 'chopper'), ('cliff dwelling',), ('cloak',),
('clog', 'geta', 'patten', 'sabot'), ('cocktail shaker',), ('coffee mug',), ('coffeepot',),
('coil', 'spiral', 'volute', 'whorl', 'helix'), ('combination lock',),
('computer keyboard', 'keypad'), ('confectionery', 'confectionary', 'candy store'),
('container ship', 'containership', 'container vessel'), ('convertible',),
('corkscrew', 'bottle screw'), ('cornet', 'horn', 'trumpet', 'trump'), ('cowboy boot',),
('cowboy hat', 'ten-gallon hat'), ('cradle',), ('crane',), ('crash helmet',), ('crate',),
('crib', 'cot'), ('Crock Pot',), ('croquet ball',), ('crutch',), ('cuirass',),
('dam', 'dike', 'dyke'), ('desk',), ('desktop computer',), ('dial telephone', 'dial phone'),
('diaper', 'nappy', 'napkin'), ('digital clock',), ('digital watch',),
('dining table', 'board'), ('dishrag', 'dishcloth'),
('dishwasher', 'dish washer', 'dishwashing machine'), ('disk brake', 'disc brake'),
('dock', 'dockage', 'docking facility'), ('dogsled', 'dog sled', 'dog sleigh'), ('dome',),
('doormat', 'welcome mat'), ('drilling platform', 'offshore rig'),
('drum', 'membranophone', 'tympan'), ('drumstick',), ('dumbbell',), ('Dutch oven',),
('electric fan', 'blower'), ('electric guitar',), ('electric locomotive',),
('entertainment center',), ('envelope',), ('espresso maker',), ('face powder',),
('feather boa', 'boa'), ('file', 'file cabinet', 'filing cabinet'), ('fireboat',),
('fire engine', 'fire truck'), ('fire screen', 'fireguard'), ('flagpole', 'flagstaff'),
('flute', 'transverse flute'), ('folding chair',), ('football helmet',), ('forklift',),
('fountain',), ('fountain pen',), ('four-poster',), ('freight car',), ('French horn', 'horn'),
('frying pan', 'frypan', 'skillet'), ('fur coat',), ('garbage truck', 'dustcart'),
('gasmask', 'respirator', 'gas helmet'),
('gas pump', 'gasoline pump', 'petrol pump', 'island dispenser'), ('goblet',), ('go-kart',),
('golf ball',), ('golfcart', 'golf cart'), ('gondola',), ('gong', 'tam-tam'), ('gown',),
('grand piano', 'grand'), ('greenhouse', 'nursery', 'glasshouse'),
('grille', 'radiator grille'), ('grocery store', 'grocery', 'food market', 'market'),
('guillotine',), ('hair slide',), ('hair spray',), ('half track',), ('hammer',), ('hamper',),
('hand blower', 'blow dryer', 'blow drier', 'hair dryer', 'hair drier'),
('hand-held computer', 'hand-held microcomputer'),
('handkerchief', 'hankie', 'hanky', 'hankey'), ('hard disc', 'hard disk', 'fixed disk'),
('harmonica', 'mouth organ', 'harp', 'mouth harp'), ('harp',), ('harvester', 'reaper'),
('hatchet',), ('holster',), ('home theater', 'home theatre'), ('honeycomb',),
('hook', 'claw'), ('hoopskirt', 'crinoline'), ('horizontal bar', 'high bar'),
('horse cart', 'horse-cart'), ('hourglass',), ('iPod',), ('iron', 'smoothing iron'),
("jack-o'-lantern",), ('jean', 'blue jean', 'denim'), ('jeep', 'landrover'),
('jersey', 'T-shirt', 'tee shirt'), ('jigsaw puzzle',), ('jinrikisha', 'ricksha', 'rickshaw'),
('joystick',), ('kimono',), ('knee pad',), ('knot',), ('lab coat', 'laboratory coat'),
('ladle',), ('lampshade', 'lamp shade'), ('laptop', 'laptop computer'),
('lawn mower', 'mower'), ('lens cap', 'lens cover'),
('letter opener', 'paper knife', 'paperknife'), ('library',), ('lifeboat',),
('lighter', 'light', 'igniter', 'ignitor'), ('limousine', 'limo'), ('liner', 'ocean liner'),
('lipstick', 'lip rouge'), ('Loafer',), ('lotion',),
('loudspeaker', 'speaker', 'speaker unit', 'loudspeaker system', 'speaker system'),
('loupe', "jeweler's loupe"), ('lumbermill', 'sawmill'), ('magnetic compass',),
('mailbag', 'postbag'), ('mailbox', 'letter box'), ('maillot',), ('maillot', 'tank suit'),
('manhole cover',), ('maraca',), ('marimba', 'xylophone'), ('mask',), ('matchstick',),
('maypole',), ('maze', 'labyrinth'), ('measuring cup',),
('medicine chest', 'medicine cabinet'), ('megalith', 'megalithic structure'),
('microphone', 'mike'), ('microwave', 'microwave oven'), ('military uniform',), ('milk can',),
('minibus',), ('miniskirt', 'mini'), ('minivan',), ('missile',), ('mitten',),
('mixing bowl',), ('mobile home', 'manufactured home'), ('Model T',), ('modem',),
('monastery',), ('monitor',), ('moped',), ('mortar',), ('mortarboard',), ('mosque',),
('mosquito net',), ('motor scooter', 'scooter'),
('mountain bike', 'all-terrain bike', 'off-roader'), ('mountain tent',),
('mouse', 'computer mouse'), ('mousetrap',), ('moving van',), ('muzzle',), ('nail',),
('neck brace',), ('necklace',), ('nipple',), ('notebook', 'notebook computer'), ('obelisk',),
('oboe', 'hautboy', 'hautbois'), ('ocarina', 'sweet potato'),
('odometer', 'hodometer', 'mileometer', 'milometer'), ('oil filter',),
('organ', 'pipe organ'), ('oscilloscope', 'scope', 'cathode-ray oscilloscope', 'CRO'),
('overskirt',), ('oxcart',), ('oxygen mask',), ('packet',), ('paddle', 'boat paddle'),
('paddlewheel', 'paddle wheel'), ('padlock',), ('paintbrush',),
('pajama', 'pyjama', "pj's", 'jammies'), ('palace',), ('panpipe', 'pandean pipe', 'syrinx'),
('paper towel',), ('parachute', 'chute'), ('parallel bars', 'bars'), ('park bench',),
('parking meter',), ('passenger car', 'coach', 'carriage'), ('patio', 'terrace'),
('pay-phone', 'pay-station'), ('pedestal', 'plinth', 'footstall'),
('pencil box', 'pencil case'), ('pencil sharpener',), ('perfume', 'essence'), ('Petri dish',),
('photocopier',), ('pick', 'plectrum', 'plectron'), ('pickelhaube',),
('picket fence', 'paling'), ('pickup', 'pickup truck'), ('pier',),
('piggy bank', 'penny bank'), ('pill bottle',), ('pillow',), ('ping-pong ball',),
('pinwheel',), ('pirate', 'pirate ship'), ('pitcher', 'ewer'),
('plane', "carpenter's plane", 'woodworking plane'), ('planetarium',), ('plastic bag',),
('plate rack',), ('plow', 'plough'), ('plunger', "plumber's helper"),
('Polaroid camera', 'Polaroid Land camera'), ('pole',),
('police van', 'police wagon', 'paddy wagon', 'patrol wagon', 'wagon', 'black Maria'),
('poncho',), ('pool table', 'billiard table', 'snooker table'), ('pop bottle', 'soda bottle'),
('pot', 'flowerpot'), ("potter's wheel",), ('power drill',), ('prayer rug', 'prayer mat'),
('printer',), ('prison', 'prison house'), ('projectile', 'missile'), ('projector',),
('puck', 'hockey puck'), ('punching bag', 'punch bag', 'punching ball', 'punchball'),
('purse',), ('quill', 'quill pen'), ('quilt', 'comforter', 'comfort', 'puff'),
('racer', 'race car', 'racing car'), ('racket', 'racquet'), ('radiator',),
('radio', 'wireless'), ('radio telescope', 'radio reflector'), ('rain barrel',),
('recreational vehicle', 'RV', 'R.V.'), ('reel',), ('reflex camera',),
('refrigerator', 'icebox'), ('remote control', 'remote'),
('restaurant', 'eating house', 'eating place', 'eatery'),
('revolver', 'six-gun', 'six-shooter'), ('rifle',), ('rocking chair', 'rocker'),
('rotisserie',), ('rubber eraser', 'rubber', 'pencil eraser'), ('rugby ball',),
('rule', 'ruler'), ('running shoe',), ('safe',), ('safety pin',),
('saltshaker', 'salt shaker'), ('sandal',), ('sarong',), ('sax', 'saxophone'), ('scabbard',),
('scale', 'weighing machine'), ('school bus',), ('schooner',), ('scoreboard',),
('screen', 'CRT screen'), ('screw',), ('screwdriver',), ('seat belt', 'seatbelt'),
('sewing machine',), ('shield', 'buckler'), ('shoe shop', 'shoe-shop', 'shoe store'),
('shoji',), ('shopping basket',), ('shopping cart',), ('shovel',), ('shower cap',),
('shower curtain',), ('ski',), ('ski mask',), ('sleeping bag',), ('slide rule', 'slipstick'),
('sliding door',), ('slot', 'one-armed bandit'), ('snorkel',), ('snowmobile',),
('snowplow', 'snowplough'), ('soap dispenser',), ('soccer ball',), ('sock',),
('solar dish', 'solar collector', 'solar furnace'), ('sombrero',), ('soup bowl',),
('space bar',), ('space heater',), ('space shuttle',), ('spatula',), ('speedboat',),
('spider web', "spider's web"), ('spindle',), ('sports car', 'sport car'),
('spotlight', 'spot'), ('stage',), ('steam locomotive',), ('steel arch bridge',),
('steel drum',), ('stethoscope',), ('stole',), ('stone wall',), ('stopwatch', 'stop watch'),
('stove',), ('strainer',), ('streetcar', 'tram', 'tramcar', 'trolley', 'trolley car'),
('stretcher',), ('studio couch', 'day bed'), ('stupa', 'tope'),
('submarine', 'pigboat', 'sub', 'U-boat'), ('suit', 'suit of clothes'), ('sundial',),
('sunglass',), ('sunglasses', 'dark glasses', 'shades'),
('sunscreen', 'sunblock', 'sun blocker'), ('suspension bridge',), ('swab', 'swob', 'mop'),
('sweatshirt',), ('swimming trunks', 'bathing trunks'), ('swing',),
('switch', 'electric switch', 'electrical switch'), ('syringe',), ('table lamp',),
('tank', 'army tank', 'armored combat vehicle', 'armoured combat vehicle'), ('tape player',),
('teapot',), ('teddy', 'teddy bear'), ('television', 'television system'), ('tennis ball',),
('thatch', 'thatched roof'), ('theater curtain', 'theatre curtain'), ('thimble',),
('thresher', 'thrasher', 'threshing machine'), ('throne',), ('tile roof',), ('toaster',),
('tobacco shop', 'tobacconist shop', 'tobacconist'), ('toilet seat',), ('torch',),
('totem pole',), ('tow truck', 'tow car', 'wrecker'), ('toyshop',), ('tractor',),
('trailer truck', 'tractor trailer', 'trucking rig', 'rig', 'articulated lorry', 'semi'),
('tray',), ('trench coat',), ('tricycle', 'trike', 'velocipede'), ('trimaran',), ('tripod',),
('triumphal arch',), ('trolleybus', 'trolley coach', 'trackless trolley'), ('trombone',),
('tub', 'vat'), ('turnstile',), ('typewriter keyboard',), ('umbrella',),
('unicycle', 'monocycle'), ('upright', 'upright piano'), ('vacuum', 'vacuum cleaner'),
('vase',), ('vault',), ('velvet',), ('vending machine',), ('vestment',), ('viaduct',),
('violin', 'fiddle'), ('volleyball',), ('waffle iron',), ('wall clock',),
('wallet', 'billfold', 'notecase', 'pocketbook'), ('wardrobe', 'closet', 'press'),
('warplane', 'military plane'),
('washbasin', 'handbasin', 'washbowl', 'lavabo', 'wash-hand basin'),
('washer', 'automatic washer', 'washing machine'), ('water bottle',), ('water jug',),
('water tower',), ('whiskey jug',), ('whistle',), ('wig',), ('window screen',),
('window shade',), ('Windsor tie',), ('wine bottle',), ('wing',), ('wok',), ('wooden spoon',),
('wool', 'woolen', 'woollen'),
('worm fence', 'snake fence', 'snake-rail fence', 'Virginia fence'), ('wreck',), ('yawl',),
('yurt',), ('web site', 'website', 'internet site', 'site'), ('comic book',),
('crossword puzzle', 'crossword'), ('street sign',),
('traffic light', 'traffic signal', 'stoplight'),
('book jacket', 'dust cover', 'dust jacket', 'dust wrapper'), ('menu',), ('plate',),
('guacamole',), ('consomme',), ('hot pot', 'hotpot'), ('trifle',), ('ice cream', 'icecream'),
('ice lolly', 'lolly', 'lollipop', 'popsicle'), ('French loaf',), ('bagel', 'beigel'),
('pretzel',), ('cheeseburger',), ('hotdog', 'hot dog', 'red hot'), ('mashed potato',),
('head cabbage',), ('broccoli',), ('cauliflower',), ('zucchini', 'courgette'),
('spaghetti squash',), ('acorn squash',), ('butternut squash',), ('cucumber', 'cuke'),
('artichoke', 'globe artichoke'), ('bell pepper',), ('cardoon',), ('mushroom',),
('Granny Smith',), ('strawberry',), ('orange',), ('lemon',), ('fig',),
('pineapple', 'ananas'), ('banana',), ('jackfruit', 'jak', 'jack'), ('custard apple',),
('pomegranate',), ('hay',), ('carbonara',), ('chocolate sauce', 'chocolate syrup'),
('dough',), ('meat loaf', 'meatloaf'), ('pizza', 'pizza pie'), ('potpie',), ('burrito',),
('red wine',), ('espresso',), ('cup',), ('eggnog',), ('alp',), ('bubble',),
('cliff', 'drop', 'drop-off'), ('coral reef',), ('geyser',), ('lakeside', 'lakeshore'),
('promontory', 'headland', 'head', 'foreland'), ('sandbar', 'sand bar'),
('seashore', 'coast', 'seacoast', 'sea-coast'), ('valley', 'vale'), ('volcano',),
('ballplayer', 'baseball player'), ('groom', 'bridegroom'), ('scuba diver',), ('rapeseed',),
('daisy',), ("yellow lady's slipper", 'yellow lady-slipper', 'Cypripedium calceolus',
'Cypripedium parviflorum'), ('corn',), ('acorn',),
('hip', 'rose hip', 'rosehip'), ('buckeye', 'horse chestnut', 'conker'), ('coral fungus',),
('agaric',), ('gyromitra',), ('stinkhorn', 'carrion fungus'), ('earthstar',),
('hen-of-the-woods', 'hen of the woods', 'Polyporus frondosus', 'Grifola frondosa'),
('bolete',), ('ear', 'spike', 'capitulum'),
('toilet tissue', 'toilet paper', 'bathroom tissue')]]
class ImageNetCaltech(ImageList):
    """ImageNet-Caltech, constructed from Caltech-256 and ImageNet-1K.

    The two datasets share 84 common classes. ImageNet-Caltech keeps all 1000
    classes of ImageNet-1K; labels follow ImageNet-1K (class 0-999). The
    private classes of Caltech-256 are discarded.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256, \
            ``'I'``: ImageNet-1K training set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and
        ImageList automatically. In `root`, there will exist following files after downloading.
        ::
            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    # Maps the short task name to its image-list file (relative to ``root``).
    image_list = {
        "I": "image_list/imagenet_train_1000_list.txt",
        "C": "image_list/caltech_84_list.txt",
    }
    CLASSES = _CLASSES

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])
        if download:
            # Each entry of ``download_list`` is (directory, archive name, url).
            for args in download_list:
                download_data(root, *args)
        else:
            # Fix: the original passed a two-parameter lambda to a
            # single-iterable map(), which raised TypeError whenever
            # download=False. Only the extracted directory is checked.
            for dir_name, _, _ in download_list:
                check_exits(root, dir_name)
        if not os.path.exists(os.path.join(root, 'train')):
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)
        super(ImageNetCaltech, self).__init__(root, ImageNetCaltech.CLASSES, data_list_file=data_list_file, **kwargs)
class ImageNetCaltechUniversal(ImageList):
    """ImageNet-Caltech-Universal, constructed from Caltech-256 and ImageNet-1K.

    The two datasets share 84 common classes. ImageNet-Caltech keeps all 1000
    classes of ImageNet-1K; labels follow ImageNet-1K (class 0-999). The
    private classes of Caltech-256 are grouped into class 1000 ("unknown").
    Thus, ImageNetCaltechUniversal has 1001 classes in total.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``: Caltech-256, \
            ``'I'``: ImageNet-1K training set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and
        ImageList automatically. In `root`, there will exist following files after downloading.
        ::
            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    # Maps the short task name to its image-list file (relative to ``root``).
    image_list = {
        "I": "image_list/imagenet_train_1000_list.txt",
        "C": "image_list/caltech_85_list.txt",
    }
    # All ImageNet-1K classes plus one extra "unknown" class (index 1000).
    CLASSES = _CLASSES + ["unknown"]

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])
        if download:
            # Each entry of ``download_list`` is (directory, archive name, url).
            for args in download_list:
                download_data(root, *args)
        else:
            # Fix: the original passed a two-parameter lambda to a
            # single-iterable map(), which raised TypeError whenever
            # download=False. Only the extracted directory is checked.
            for dir_name, _, _ in download_list:
                check_exits(root, dir_name)
        if not os.path.exists(os.path.join(root, 'train')):
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)
        super(ImageNetCaltechUniversal, self).__init__(root, ImageNetCaltechUniversal.CLASSES, data_list_file=data_list_file, **kwargs)
# (extracted directory, archive file name, download URL) triples consumed by
# ImageNetCaltech / ImageNetCaltechUniversal via ``download_data`` and
# ``check_exits``. Defined after the classes, but resolved at call time.
download_list = [
    ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1"),
    ("256_ObjectCategories", "256_ObjectCategories.tar",
     "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar"),
]
================================================
FILE: tllib/vision/datasets/patchcamelyon.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class PatchCamelyon(ImageList):
    """The PatchCamelyon dataset: 327680 histopathologic scans of lymph node
    sections. The task is binary classification — predict whether the given
    96x96 image contains metastatic tissue.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    # Binary labels: '0' = no metastatic tissue, '1' = metastatic tissue present.
    CLASSES = ['0', '1']

    def __init__(self, root, split, download=False, **kwargs):
        if download:
            download_data(root, "patch_camelyon", "patch_camelyon.tgz",
                          "https://cloud.tsinghua.edu.cn/f/21360b3441a54274b843/?dl=1")
        else:
            check_exits(root, "patch_camelyon")
        # All data lives under a "patch_camelyon" subdirectory of ``root``.
        dataset_root = os.path.join(root, "patch_camelyon")
        list_file = os.path.join(dataset_root, "imagelist", "{}.txt".format(split))
        super(PatchCamelyon, self).__init__(dataset_root, PatchCamelyon.CLASSES, list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/regression/__init__.py
================================================
from .image_regression import ImageRegression
from .dsprites import DSprites
from .mpi3d import MPI3D
================================================
FILE: tllib/vision/datasets/regression/dsprites.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Sequence
import os
from .._util import download as download_data, check_exits
from .image_regression import ImageRegression
class DSprites(ImageRegression):
    """The dSprites dataset for regression-style domain adaptation.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``: Color, \
            ``'N'``: Noisy and ``'S'``: Scream.
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        factors (sequence[str]): Factors selected. Default: ('scale', 'position x', 'position y').
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            color/
                ...
            noisy/
            scream/
            image_list/
                color_train.txt
                noisy_train.txt
                scream_train.txt
                color_test.txt
                noisy_test.txt
                scream_test.txt
    """
    # (extracted directory, archive file name, download URL) triples.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/4392ef903ed14017a042/?dl=1"),
        ("color", "color.tgz", "https://cloud.tsinghua.edu.cn/f/6d243c589d384ff5a212/?dl=1"),
        ("noisy", "noisy.tgz", "https://cloud.tsinghua.edu.cn/f/9a23ede3be1740328637/?dl=1"),
        ("scream", "scream.tgz", "https://cloud.tsinghua.edu.cn/f/8fc4d34311bb4db6bcde/?dl=1"),
    ]
    # Maps the short task name to the domain directory / image-list prefix.
    image_list = {
        "C": "color",
        "N": "noisy",
        "S": "scream"
    }
    FACTORS = ('none', 'shape', 'scale', 'orientation', 'position x', 'position y')

    def __init__(self, root: str, task: str, split: Optional[str] = 'train',
                 factors: Sequence[str] = ('scale', 'position x', 'position y'),
                 download: Optional[bool] = True, target_transform=None, **kwargs):
        assert task in self.image_list
        assert split in ['train', 'test']
        for factor in factors:
            assert factor in self.FACTORS
        # Indices into the per-sample factor vector for the requested factors.
        factor_index = [self.FACTORS.index(factor) for factor in factors]
        if target_transform is None:
            target_transform = lambda x: x[factor_index]
        else:
            # Fix: the original lambda referenced the rebound name
            # ``target_transform`` itself (late binding), so calling it
            # recursed infinitely. Capture the user callback under a
            # distinct name before wrapping it.
            user_transform = target_transform
            target_transform = lambda x: user_transform(x[factor_index])
        data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))
        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Fix: map() with a two-parameter lambda over a single iterable
            # raised TypeError; only the extracted directory is checked.
            for dir_name, _, _ in self.download_list:
                check_exits(root, dir_name)
        super(DSprites, self).__init__(root, factors, data_list_file=data_list_file,
                                       target_transform=target_transform, **kwargs)
================================================
FILE: tllib/vision/datasets/regression/image_regression.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional, Callable, Tuple, Any, List, Sequence
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader
import numpy as np
class ImageRegression(datasets.VisionDataset):
    """A generic Dataset class for domain adaptation in image regression.

    Args:
        root (str): Root directory of dataset
        factors (sequence[str]): Factors selected. Default: ('scale', 'position x', 'position y').
        data_list_file (str): File to read the image list from.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note::
        In `data_list_file`, each line has `1+len(factors)` values in the following format.
        ::
            source_dir/dog_xxx.png x11, x12, ...
            source_dir/cat_123.png x21, x22, ...
            target_dir/dog_xxy.png x31, x32, ...
            target_dir/cat_nsdf3.png x41, x42, ...

        The first value is the relative path of an image, and the rest values are the ground
        truth of the corresponding factors.
        If your data_list_file has different formats, please over-ride
        :meth:`ImageRegression.parse_data_file`.
    """

    def __init__(self, root: str, factors: Sequence[str], data_list_file: str,
                 transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.samples = self.parse_data_file(data_list_file)
        self.factors = factors
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, Tuple[float]]:
        """
        Args:
            index (int): Index

        Returns:
            (image, target) where target is a numpy float array.
        """
        path, target = self.samples[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None and target is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str) -> List[Tuple[str, Any]]:
        """Parse file to data list.

        Args:
            file_name (str): The path of data file

        Returns:
            List of (image path, factors) tuples, where the factors are a
            1-D ``np.float64`` array.
        """
        data_list = []
        with open(file_name, "r") as f:
            for line in f:
                data = line.split()
                path = str(data[0])
                # Fix: ``np.float`` was deprecated in NumPy 1.20 and removed
                # in 1.24; ``np.float64`` is the exact dtype the alias named.
                target = np.array([float(d) for d in data[1:]], dtype=np.float64)
                if not os.path.isabs(path):
                    # Relative paths are resolved against the dataset root.
                    path = os.path.join(self.root, path)
                data_list.append((path, target))
        return data_list

    @property
    def num_factors(self) -> int:
        """Number of selected factors (length of the regression target)."""
        return len(self.factors)
================================================
FILE: tllib/vision/datasets/regression/mpi3d.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Sequence
import os
from .._util import download as download_data, check_exits
from .image_regression import ImageRegression
class MPI3D(ImageRegression):
    """The MPI3D dataset for regression-style domain adaptation.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'RL'``: real, \
            ``'RC'``: realistic and ``'T'``: toy.
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        factors (sequence[str]): Factors selected. Default: ('horizontal axis', 'vertical axis').
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            real/
                ...
            realistic/
            toy/
            image_list/
                real_train.txt
                realistic_train.txt
                toy_train.txt
                real_test.txt
                realistic_test.txt
                toy_test.txt
    """
    # (extracted directory, archive file name, download URL) triples.
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/feacec494d5347b7a6aa/?dl=1"),
        ("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/605dd842cd9d4071a0ae/?dl=1"),
        ("realistic", "realistic.tgz", "https://cloud.tsinghua.edu.cn/f/05743f3071054cc29e25/?dl=1"),
        ("toy", "toy.tgz", "https://cloud.tsinghua.edu.cn/f/1511dff7853d4abea38f/?dl=1"),
    ]
    # Maps the short task name to the domain directory / image-list prefix.
    image_list = {
        "RL": "real",
        "RC": "realistic",
        "T": "toy"
    }
    FACTORS = ('horizontal axis', 'vertical axis')

    def __init__(self, root: str, task: str, split: Optional[str] = 'train',
                 factors: Sequence[str] = ('horizontal axis', 'vertical axis'),
                 download: Optional[bool] = True, target_transform=None, **kwargs):
        assert task in self.image_list
        assert split in ['train', 'test']
        for factor in factors:
            assert factor in self.FACTORS
        # Indices into the per-sample factor vector for the requested factors.
        factor_index = [self.FACTORS.index(factor) for factor in factors]
        # The division by 39 normalizes raw factor values into [0, 1]
        # (presumably each axis takes 40 discrete positions — verify against
        # the MPI3D data description).
        if target_transform is None:
            target_transform = lambda x: x[factor_index] / 39.
        else:
            # Fix: the original lambda referenced the rebound name
            # ``target_transform`` itself (late binding), so calling it
            # recursed infinitely. Capture the user callback under a
            # distinct name before wrapping it.
            user_transform = target_transform
            target_transform = lambda x: user_transform(x[factor_index]) / 39.
        data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split))
        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # Fix: map() with a two-parameter lambda over a single iterable
            # raised TypeError; only the extracted directory is checked.
            for dir_name, _, _ in self.download_list:
                check_exits(root, dir_name)
        super(MPI3D, self).__init__(root, factors, data_list_file=data_list_file,
                                    target_transform=target_transform, **kwargs)
================================================
FILE: tllib/vision/datasets/reid/__init__.py
================================================
from .market1501 import Market1501
from .dukemtmc import DukeMTMC
from .msmt17 import MSMT17
from .personx import PersonX
from .unreal import UnrealPerson
__all__ = ['Market1501', 'DukeMTMC', 'MSMT17', 'PersonX', 'UnrealPerson']
================================================
FILE: tllib/vision/datasets/reid/basedataset.py
================================================
"""
Modified from https://github.com/yxgeee/MMT
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os.path as osp
import numpy as np
class BaseDataset(object):
"""
Base class of reid dataset
"""
def get_imagedata_info(self, data):
pids, cams = [], []
for _, pid, camid in data:
pids += [pid]
cams += [camid]
pids = set(pids)
cams = set(cams)
num_pids = len(pids)
num_cams = len(cams)
num_imgs = len(data)
return num_pids, num_imgs, num_cams
def get_videodata_info(self, data, return_tracklet_stats=False):
pids, cams, tracklet_stats = [], [], []
for img_paths, pid, camid in data:
pids += [pid]
cams += [camid]
tracklet_stats += [len(img_paths)]
pids = set(pids)
cams = set(cams)
num_pids = len(pids)
num_cams = len(cams)
num_tracklets = len(data)
if return_tracklet_stats:
return num_pids, num_tracklets, num_cams, tracklet_stats
return num_pids, num_tracklets, num_cams
def print_dataset_statistics(self, train, query, galler):
raise NotImplementedError
def check_before_run(self, required_files):
"""Checks if required files exist before going deeper.
Args:
required_files (str or list): string file name(s).
"""
if isinstance(required_files, str):
required_files = [required_files]
for fpath in required_files:
if not osp.exists(fpath):
raise RuntimeError('"{}" is not found'.format(fpath))
@property
def images_dir(self):
return None
class BaseImageDataset(BaseDataset):
"""
Base class of image reid dataset
"""
def print_dataset_statistics(self, train, query, gallery):
num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train)
num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query)
num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery)
print("Dataset statistics:")
print(" ----------------------------------------")
print(" subset | # ids | # images | # cameras")
print(" ----------------------------------------")
print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams))
print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams))
print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams))
print(" ----------------------------------------")
class BaseVideoDataset(BaseDataset):
"""
Base class of video reid dataset
"""
def print_dataset_statistics(self, train, query, gallery):
num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \
self.get_videodata_info(train, return_tracklet_stats=True)
num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \
self.get_videodata_info(query, return_tracklet_stats=True)
num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \
self.get_videodata_info(gallery, return_tracklet_stats=True)
tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats
min_num = np.min(tracklet_stats)
max_num = np.max(tracklet_stats)
avg_num = np.mean(tracklet_stats)
print("Dataset statistics:")
print(" -------------------------------------------")
print(" subset | # ids | # tracklets | # cameras")
print(" -------------------------------------------")
print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams))
print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams))
print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams))
print(" -------------------------------------------")
print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num))
print(" -------------------------------------------")
================================================
FILE: tllib/vision/datasets/reid/convert.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os.path as osp
from torch.utils.data import Dataset
from PIL import Image
def convert_to_pytorch_dataset(dataset, root=None, transform=None, return_idxes=False):
class ReidDataset(Dataset):
def __init__(self, dataset, root, transform):
super(ReidDataset, self).__init__()
self.dataset = dataset
self.root = root
self.transform = transform
self.return_idxes = return_idxes
def __len__(self):
return len(self.dataset)
def __getitem__(self, index):
fname, pid, cid = self.dataset[index]
fpath = fname
if self.root is not None:
fpath = osp.join(self.root, fname)
img = Image.open(fpath).convert('RGB')
if self.transform is not None:
img = self.transform(img)
if not self.return_idxes:
return img, fname, pid, cid
else:
return img, fname, pid, cid, index
return ReidDataset(dataset, root, transform)
================================================
FILE: tllib/vision/datasets/reid/dukemtmc.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from .basedataset import BaseImageDataset
from typing import Callable
from PIL import Image
import os
import os.path as osp
import glob
import re
from tllib.vision.datasets._util import download
class DukeMTMC(BaseImageDataset):
"""DukeMTMC-reID dataset from `Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking
(ECCV 2016) `_.
Dataset statistics:
- identities: 1404 (train + query)
- images:16522 (train) + 2228 (query) + 17661 (gallery)
- cameras: 8
Args:
root (str): Root directory of dataset
verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True
"""
dataset_dir = '.'
archive_name = 'DukeMTMC-reID.tgz'
dataset_url = 'https://cloud.tsinghua.edu.cn/f/cb80f49905ee4e8eb9f0/?dl=1'
def __init__(self, root, verbose=True):
super(DukeMTMC, self).__init__()
download(root, self.dataset_dir, self.archive_name, self.dataset_url)
self.relative_dataset_dir = self.dataset_dir
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query')
self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test')
required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, relabel=True)
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
if verbose:
print("=> DukeMTMC-reID loaded")
self.print_dataset_statistics(train, query, gallery)
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
dataset = []
for img_path in img_paths:
pid, cid = map(int, pattern.search(img_path).groups())
assert 1 <= cid <= 8
cid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
dataset.append((img_path, pid, cid))
return dataset
def translate(self, transform: Callable, target_root: str):
""" Translate an image and save it into a specified directory
Args:
transform (callable): a transform function that maps images from one domain to another domain
target_root (str): the root directory to save images
"""
os.makedirs(target_root, exist_ok=True)
translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)
translated_train_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_train')
translated_query_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/query')
translated_gallery_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_test')
print("Translating dataset with image to image transform...")
self.translate_dir(transform, self.train_dir, translated_train_dir)
self.translate_dir(None, self.query_dir, translated_query_dir)
self.translate_dir(None, self.gallery_dir, translated_gallery_dir)
print("Translation process is done, save dataset to {}".format(translated_dataset_dir))
def translate_dir(self, transform, origin_dir: str, target_dir: str):
image_list = os.listdir(origin_dir)
for image_name in image_list:
if not image_name.endswith(".jpg"):
continue
image_path = osp.join(origin_dir, image_name)
image = Image.open(image_path)
translated_image_path = osp.join(target_dir, image_name)
translated_image = image
if transform:
translated_image = transform(image)
os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)
translated_image.save(translated_image_path)
================================================
FILE: tllib/vision/datasets/reid/market1501.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from .basedataset import BaseImageDataset
from typing import Callable
from PIL import Image
import os
import os.path as osp
import glob
import re
from tllib.vision.datasets._util import download
class Market1501(BaseImageDataset):
"""Market1501 dataset from `Scalable Person Re-identification: A Benchmark (ICCV 2015)
`_.
Dataset statistics:
- identities: 1501 (+1 for background)
- images: 12936 (train) + 3368 (query) + 15913 (gallery)
- cameras: 6
Args:
root (str): Root directory of dataset
verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True
"""
dataset_dir = 'Market-1501-v15.09.15'
archive_name = 'Market-1501-v15.09.15.tgz'
dataset_url = 'https://cloud.tsinghua.edu.cn/f/29e5f015a7314531b645/?dl=1'
def __init__(self, root, verbose=True):
super(Market1501, self).__init__()
download(root, self.dataset_dir, self.archive_name, self.dataset_url)
self.relative_dataset_dir = self.dataset_dir
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, relabel=True)
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
if verbose:
print("=> Market1501 loaded")
self.print_dataset_statistics(train, query, gallery)
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c(\d)')
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
if pid == -1:
continue # junk images are just ignored
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
dataset = []
for img_path in img_paths:
pid, cid = map(int, pattern.search(img_path).groups())
if pid == -1:
continue # junk images are just ignored
assert 0 <= pid <= 1501 # pid == 0 means background
assert 1 <= cid <= 6
cid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
dataset.append((img_path, pid, cid))
return dataset
def translate(self, transform: Callable, target_root: str):
""" Translate an image and save it into a specified directory
Args:
transform (callable): a transform function that maps images from one domain to another domain
target_root (str): the root directory to save images
"""
os.makedirs(target_root, exist_ok=True)
translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)
translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')
translated_query_dir = osp.join(translated_dataset_dir, 'query')
translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')
print("Translating dataset with image to image transform...")
self.translate_dir(transform, self.train_dir, translated_train_dir)
self.translate_dir(None, self.query_dir, translated_query_dir)
self.translate_dir(None, self.gallery_dir, translated_gallery_dir)
print("Translation process is done, save dataset to {}".format(translated_dataset_dir))
def translate_dir(self, transform, origin_dir: str, target_dir: str):
image_list = os.listdir(origin_dir)
for image_name in image_list:
if not image_name.endswith(".jpg"):
continue
image_path = osp.join(origin_dir, image_name)
image = Image.open(image_path)
translated_image_path = osp.join(target_dir, image_name)
translated_image = image
if transform:
translated_image = transform(image)
os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)
translated_image.save(translated_image_path)
================================================
FILE: tllib/vision/datasets/reid/msmt17.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from .basedataset import BaseImageDataset
from typing import Callable
from PIL import Image
import os
import os.path as osp
from tllib.vision.datasets._util import download
class MSMT17(BaseImageDataset):
"""MSMT17 dataset from `Person Transfer GAN to Bridge Domain Gap for Person Re-Identification (CVPR 2018)
`_.
Dataset statistics:
- identities: 4101
- images: 32621 (train) + 11659 (query) + 82161 (gallery)
- cameras: 15
Args:
root (str): Root directory of dataset
verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True
"""
dataset_dir = '.'
archive_name = 'MSMT17_V1.zip'
dataset_url = 'https://cloud.tsinghua.edu.cn/f/c254ea490cfa4115940d/?dl=1'
def __init__(self, root, verbose=True):
super(MSMT17, self).__init__()
download(root, self.dataset_dir, self.archive_name, self.dataset_url)
self.relative_dataset_dir = self.dataset_dir
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]
self.check_before_run(required_files)
self.train = self.process_dir(self.train_dir)
self.query = self.process_dir(self.query_dir)
self.gallery = self.process_dir(self.gallery_dir)
if verbose:
print("=> MSMT17 loaded")
self.print_dataset_statistics(self.train, self.query, self.gallery)
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def process_dir(self, dir_path):
image_list = os.listdir(dir_path)
dataset = []
pid_container = set()
for image_path in image_list:
pid, cid, _ = image_path.split('_')
pid = int(pid)
cid = int(cid[1:]) - 1 # index starts from 0
full_image_path = osp.join(dir_path, image_path)
dataset.append((full_image_path, pid, cid))
pid_container.add(pid)
# check if pid starts from 0 and increments with 1
for idx, pid in enumerate(pid_container):
assert idx == pid, "See code comment for explanation"
return dataset
def translate(self, transform: Callable, target_root: str):
""" Translate an image and save it into a specified directory
Args:
transform (callable): a transform function that maps images from one domain to another domain
target_root (str): the root directory to save images
"""
os.makedirs(target_root, exist_ok=True)
translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)
translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')
translated_query_dir = osp.join(translated_dataset_dir, 'query')
translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')
print("Translating dataset with image to image transform...")
self.translate_dir(transform, self.train_dir, translated_train_dir)
self.translate_dir(None, self.query_dir, translated_query_dir)
self.translate_dir(None, self.gallery_dir, translated_gallery_dir)
print("Translation process is done, save dataset to {}".format(translated_dataset_dir))
def translate_dir(self, transform, origin_dir: str, target_dir: str):
image_list = os.listdir(origin_dir)
for image_name in image_list:
if not image_name.endswith(".jpg"):
continue
image_path = osp.join(origin_dir, image_name)
image = Image.open(image_path)
translated_image_path = osp.join(target_dir, image_name)
translated_image = image
if transform:
translated_image = transform(image)
os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)
translated_image.save(translated_image_path)
================================================
FILE: tllib/vision/datasets/reid/personx.py
================================================
"""
Modified from https://github.com/yxgeee/SpCL
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from .basedataset import BaseImageDataset
from typing import Callable
from PIL import Image
import os
import os.path as osp
import glob
import re
from tllib.vision.datasets._util import download
class PersonX(BaseImageDataset):
"""PersonX dataset from `Dissecting Person Re-identification from the Viewpoint of Viewpoint (CVPR 2019)
`_.
Dataset statistics:
- identities: 1266
- images: 9840 (train) + 5136 (query) + 30816 (gallery)
- cameras: 6
Args:
root (str): Root directory of dataset
verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True
"""
dataset_dir = '.'
archive_name = 'PersonX.zip'
dataset_url = 'https://cloud.tsinghua.edu.cn/f/f506cd11d6b646729bd1/?dl=1'
def __init__(self, root, verbose=True):
super(PersonX, self).__init__()
download(root, self.dataset_dir, self.archive_name, self.dataset_url)
self.relative_dataset_dir = self.dataset_dir
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
self.query_dir = osp.join(self.dataset_dir, 'query')
self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')
required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir]
self.check_before_run(required_files)
train = self.process_dir(self.train_dir, relabel=True)
query = self.process_dir(self.query_dir, relabel=False)
gallery = self.process_dir(self.gallery_dir, relabel=False)
if verbose:
print("=> PersonX loaded")
self.print_dataset_statistics(train, query, gallery)
self.train = train
self.query = query
self.gallery = gallery
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query)
self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery)
def process_dir(self, dir_path, relabel=False):
img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
pattern = re.compile(r'([-\d]+)_c([-\d]+)')
cam2label = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6}
pid_container = set()
for img_path in img_paths:
pid, _ = map(int, pattern.search(img_path).groups())
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(pid_container)}
dataset = []
for img_path in img_paths:
pid, cid = map(int, pattern.search(img_path).groups())
assert (cid in cam2label.keys())
cid = cam2label[cid]
cid -= 1 # index starts from 0
if relabel:
pid = pid2label[pid]
dataset.append((img_path, pid, cid))
return dataset
def translate(self, transform: Callable, target_root: str):
""" Translate an image and save it into a specified directory
Args:
transform (callable): a transform function that maps images from one domain to another domain
target_root (str): the root directory to save images
"""
os.makedirs(target_root, exist_ok=True)
translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir)
translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train')
translated_query_dir = osp.join(translated_dataset_dir, 'query')
translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test')
print("Translating dataset with image to image transform...")
self.translate_dir(transform, self.train_dir, translated_train_dir)
self.translate_dir(None, self.query_dir, translated_query_dir)
self.translate_dir(None, self.gallery_dir, translated_gallery_dir)
print("Translation process is done, save dataset to {}".format(translated_dataset_dir))
def translate_dir(self, transform, origin_dir: str, target_dir: str):
image_list = os.listdir(origin_dir)
for image_name in image_list:
if not image_name.endswith(".jpg"):
continue
image_path = osp.join(origin_dir, image_name)
image = Image.open(image_path)
translated_image_path = osp.join(target_dir, image_name)
translated_image = image
if transform:
translated_image = transform(image)
os.makedirs(os.path.dirname(translated_image_path), exist_ok=True)
translated_image.save(translated_image_path)
================================================
FILE: tllib/vision/datasets/reid/unreal.py
================================================
"""
Modified from https://github.com/SikaStar/IDM
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from .basedataset import BaseImageDataset
from typing import Callable
import os.path as osp
from tllib.vision.datasets._util import download
class UnrealPerson(BaseImageDataset):
"""UnrealPerson dataset from `UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification
(CVPR 2021) `_.
Dataset statistics:
- identities: 3000
- images: 120,000
- cameras: 34
Args:
root (str): Root directory of dataset
verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True
"""
dataset_dir = '.'
download_list = [
("list_unreal_train.txt", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a51b22fd760743e7bca6/?dl=1"),
("unreal_v1.1", "unreal_v1.1.tar", "https://cloud.tsinghua.edu.cn/f/a8806bb3bf1744dda5b1/?dl=1"),
("unreal_v1.2", "unreal_v1.2.tar", "https://cloud.tsinghua.edu.cn/f/449224485a654c5baa8f/?dl=1"),
("unreal_v1.3", "unreal_v1.3.tar", "https://cloud.tsinghua.edu.cn/f/069f3162f74849c09c10/?dl=1"),
("unreal_v2.1", "unreal_v2.1.tar", "https://cloud.tsinghua.edu.cn/f/a791aaa42674466eb183/?dl=1"),
("unreal_v2.2", "unreal_v2.2.tar", "https://cloud.tsinghua.edu.cn/f/b601d9f54f964248bd0e/?dl=1"),
("unreal_v2.3", "unreal_v2.3.tar", "https://cloud.tsinghua.edu.cn/f/311ec60e810b42d48d12/?dl=1"),
("unreal_v3.1", "unreal_v3.1.tar", "https://cloud.tsinghua.edu.cn/f/d51b7c1d125e4632bcf9/?dl=1"),
("unreal_v3.2", "unreal_v3.2.tar", "https://cloud.tsinghua.edu.cn/f/4efbd969ea2f4e8197e8/?dl=1"),
("unreal_v3.3", "unreal_v3.3.tar", "https://cloud.tsinghua.edu.cn/f/a3cc3d9c460247848fb7/?dl=1"),
("unreal_v4.1", "unreal_v4.1.tar", "https://cloud.tsinghua.edu.cn/f/ca05183ac9cd4be5a53b/?dl=1"),
("unreal_v4.2", "unreal_v4.2.tar", "https://cloud.tsinghua.edu.cn/f/b90722cbd754496f9f40/?dl=1"),
("unreal_v4.3", "unreal_v4.3.tar", "https://cloud.tsinghua.edu.cn/f/547ae646c3d346038297/?dl=1"),
]
def __init__(self, root, verbose=True):
super(UnrealPerson, self).__init__()
list(map(lambda args: download(root, *args), self.download_list))
self.dataset_dir = osp.join(root, self.dataset_dir)
self.train_list = osp.join(self.dataset_dir, 'list_unreal_train.txt')
required_files = [self.dataset_dir, self.train_list]
self.check_before_run(required_files)
train = self.process_dir(self.train_list)
self.train = train
self.query = []
self.gallery = []
self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train)
if verbose:
print("=> UnrealPerson loaded")
print(" ----------------------------------------")
print(" subset | # ids | # cams | # images")
print(" ----------------------------------------")
print(" train | {:5d} | {:5d} | {:8d}"
.format(self.num_train_pids, self.num_train_cams, self.num_train_imgs))
print(" ----------------------------------------")
def process_dir(self, list_file):
with open(list_file, 'r') as f:
lines = f.readlines()
dataset = []
pid_container = set()
for line in lines:
line = line.strip()
pid = line.split(' ')[1]
pid_container.add(pid)
pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}
for line in lines:
line = line.strip()
fname, pid, cid = line.split(' ')[0], line.split(' ')[1], int(line.split(' ')[2])
img_path = osp.join(self.dataset_dir, fname)
dataset.append((img_path, pid2label[pid], cid))
return dataset
def translate(self, transform: Callable, target_root: str):
raise NotImplementedError
================================================
FILE: tllib/vision/datasets/resisc45.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from torchvision.datasets.folder import ImageFolder
import random
class Resisc45(ImageFolder):
"""`Resisc45 `_ dataset \
is a scene classification task from remote sensing images. There are 45 classes, \
containing 700 images each, including tennis court, ship, island, lake, \
parking lot, sparse residential, or stadium. \
The image size is RGB 256x256 pixels.
.. note:: You need to download the source data manually into `root` directory.
Args:
root (str): Root directory of dataset
split (str, optional): The dataset split, supports ``train``, or ``test``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a \
transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""
def __init__(self, root, split='train', download=False, **kwargs):
super(Resisc45, self).__init__(root, **kwargs)
random.seed(0)
random.shuffle(self.samples)
if split == 'train':
self.samples = self.samples[:25200]
else:
self.samples = self.samples[25200:]
@property
def num_classes(self) -> int:
"""Number of classes"""
return len(self.classes)
================================================
FILE: tllib/vision/datasets/retinopathy.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .imagelist import ImageList
class Retinopathy(ImageList):
"""`Retinopathy `_ dataset \
consists of image-label pairs with high-resolution retina images, and labels that indicate \
the presence of Diabetic Retinopahy (DR) in a 0-4 scale (No DR, Mild, Moderate, Severe, \
or Proliferative DR).
.. note:: You need to download the source data manually into `root` directory.
Args:
root (str): Root directory of dataset
split (str, optional): The dataset split, supports ``train``, or ``test``.
transform (callable, optional): A function/transform that takes in an PIL image and returns a \
transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
"""
CLASSES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR']
def __init__(self, root, split, download=False, **kwargs):
super(Retinopathy, self).__init__(os.path.join(root, split), Retinopathy.CLASSES, os.path.join(root, "image_list", "{}.txt".format(split)), **kwargs)
================================================
FILE: tllib/vision/datasets/segmentation/__init__.py
================================================
from .segmentation_list import SegmentationList
from .cityscapes import Cityscapes, FoggyCityscapes
from .gta5 import GTA5
from .synthia import Synthia
__all__ = ["SegmentationList", "Cityscapes", "GTA5", "Synthia", "FoggyCityscapes"]
================================================
FILE: tllib/vision/datasets/segmentation/cityscapes.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .segmentation_list import SegmentationList
from .._util import download as download_data
class Cityscapes(SegmentationList):
    """Cityscapes, a real-world semantic segmentation dataset collected in driving scenarios.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``val``.
        data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit'.
        label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'.
        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.
        transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \
            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.

    .. note:: You need to download Cityscapes manually.
        Ensure that there exist following files in the `root` directory before you using this class.
        ::
            leftImg8bit/
                train/
                val/
                test/
            gtFine/
                train/
                val/
                test/
    """
    # names of the 19 evaluated Cityscapes classes
    CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign',
               'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
               'bicycle']
    # raw Cityscapes label ids -> contiguous train ids; ids absent from this map are ignored
    ID_TO_TRAIN_ID = {
        7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5,
        19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12,
        26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18
    }
    # color used to render each train id; the extra trailing black entry is used for ignored pixels
    TRAIN_ID_TO_COLOR = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156),
                         (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                         (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60),
                         (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100),
                         (0, 0, 230), (119, 11, 32), [0, 0, 0]]
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a23536bb8a724e91af39/?dl=1"),
    ]
    EVALUATE_CLASSES = CLASSES

    def __init__(self, root, split='train', data_folder='leftImg8bit', label_folder='gtFine', **kwargs):
        assert split in ['train', 'val']
        # fetch the meta information (image lists) from the Internet when absent
        for entry in self.download_list:
            download_data(root, *entry)
        self.split = split
        split_list_file = os.path.join(root, "image_list", "{}.txt".format(split))
        super(Cityscapes, self).__init__(root, Cityscapes.CLASSES, split_list_file, split_list_file,
                                         os.path.join(data_folder, split), os.path.join(label_folder, split),
                                         id_to_train_id=Cityscapes.ID_TO_TRAIN_ID,
                                         train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)

    def parse_label_file(self, label_list_file):
        """Build the label-path list from ``label_list_file``.

        Each label path is derived from the corresponding image path by
        substituting the ``gtFine_labelIds`` suffix for ``leftImg8bit``.
        """
        with open(label_list_file, "r") as f:
            return [line.strip().replace("leftImg8bit", "gtFine_labelIds") for line in f]
class FoggyCityscapes(Cityscapes):
    """Foggy Cityscapes, a real-world semantic segmentation dataset collected
    in foggy driving scenarios.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``val``.
        data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit_foggy'.
        label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'.
        beta (float, optional): The parameter for foggy. Choices includes: 0.005, 0.01, 0.02. Default: 0.02
        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.
        transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \
            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.

    .. note:: You need to download Foggy Cityscapes manually.
        Ensure that there exist following files in the `root` directory before you using this class.
        ::
            leftImg8bit_foggy/
                train/
                val/
                test/
            gtFine/
                train/
                val/
                test/
    """

    def __init__(self, root, split='train', data_folder='leftImg8bit_foggy', label_folder='gtFine', beta=0.02,
                 **kwargs):
        assert beta in [0.02, 0.01, 0.005]
        # ``beta`` must be assigned before the parent constructor runs: the parent
        # calls :meth:`parse_data_file`, which reads ``self.beta``.
        self.beta = beta
        super(FoggyCityscapes, self).__init__(root, split, data_folder, label_folder, **kwargs)

    def parse_data_file(self, file_name):
        """Parse file to image list

        Args:
            file_name (str): The path of data file

        Returns:
            List of image path
        """
        # redirect each image path to the foggy rendering for the chosen density
        foggy_token = "leftImg8bit_foggy_beta_{}".format(self.beta)
        with open(file_name, "r") as f:
            return [line.strip().replace("leftImg8bit", foggy_token) for line in f]
================================================
FILE: tllib/vision/datasets/segmentation/gta5.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .segmentation_list import SegmentationList
from .cityscapes import Cityscapes
from .._util import download as download_data
class GTA5(SegmentationList):
    """GTA5, a synthetic semantic segmentation dataset labeled with the Cityscapes class set.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``.
        data_folder (str, optional): Sub-directory of the image. Default: 'images'.
        label_folder (str, optional): Sub-directory of the label. Default: 'labels'.
        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.
        transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \
            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.

    .. note:: You need to download GTA5 manually.
        Ensure that there exist following directories in the `root` directory before you using this class.
        ::
            images/
            labels/
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f719733e339544e9a330/?dl=1"),
    ]

    def __init__(self, root, split='train', data_folder='images', label_folder='labels', **kwargs):
        assert split in ['train']
        # fetch the meta information (image lists) from the Internet when absent
        for entry in self.download_list:
            download_data(root, *entry)
        self.split = split
        split_list_file = os.path.join(root, "image_list", "{}.txt".format(split))
        super(GTA5, self).__init__(root, Cityscapes.CLASSES, split_list_file, split_list_file, data_folder,
                                   label_folder,
                                   id_to_train_id=Cityscapes.ID_TO_TRAIN_ID,
                                   train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)
================================================
FILE: tllib/vision/datasets/segmentation/segmentation_list.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Sequence, Optional, Dict, Callable
from PIL import Image
import tqdm
import numpy as np
from torch.utils import data
import torch
class SegmentationList(data.Dataset):
    """A generic Dataset class for domain adaptation in image segmentation

    Args:
        root (str): Root directory of dataset
        classes (seq[str]): The names of all the classes
        data_list_file (str): File to read the image list from.
        label_list_file (str): File to read the label list from.
        data_folder (str): Sub-directory of the image.
        label_folder (str): Sub-directory of the label.
        mean (seq[float]): mean BGR value. Normalize and convert to the image if not None. Default: None.
        id_to_train_id (dict, optional): the map between the id on the label and the actual train id.
        train_id_to_color (seq, optional): the map between the train id and the color.
        transforms (callable, optional): A function/transform that takes in (PIL Image, label) pair \
            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.

    .. note:: In ``data_list_file``, each line is the relative path of an image.
        If your data_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_data_file`.
        ::
            source_dir/dog_xxx.png
            target_dir/dog_xxy.png

        In ``label_list_file``, each line is the relative path of an label.
        If your label_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_label_file`.

    .. warning:: When mean is not None, please do not provide Normalize and ToTensor in transforms.
    """

    def __init__(self, root: str, classes: Sequence[str], data_list_file: str, label_list_file: str,
                 data_folder: str, label_folder: str,
                 id_to_train_id: Optional[Dict] = None, train_id_to_color: Optional[Sequence] = None,
                 transforms: Optional[Callable] = None):
        self.root = root
        self.classes = classes
        self.data_list_file = data_list_file
        self.label_list_file = label_list_file
        self.data_folder = data_folder
        self.label_folder = label_folder
        # pixels mapped to this value are excluded from training/evaluation
        self.ignore_label = 255
        self.id_to_train_id = id_to_train_id
        self.train_id_to_color = np.array(train_id_to_color)
        self.data_list = self.parse_data_file(self.data_list_file)
        self.label_list = self.parse_label_file(self.label_list_file)
        self.transforms = transforms

    def parse_data_file(self, file_name):
        """Parse file to image list

        Args:
            file_name (str): The path of data file

        Returns:
            List of image path
        """
        with open(file_name, "r") as f:
            data_list = [line.strip() for line in f.readlines()]
        return data_list

    def parse_label_file(self, file_name):
        """Parse file to label list

        Args:
            file_name (str): The path of data file

        Returns:
            List of label path
        """
        with open(file_name, "r") as f:
            label_list = [line.strip() for line in f.readlines()]
        return label_list

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, index):
        """Return the (image, label) pair at ``index``, with label ids remapped to train ids."""
        image_name = self.data_list[index]
        label_name = self.label_list[index]
        image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB')
        label = Image.open(os.path.join(self.root, self.label_folder, label_name))
        # bugfix: ``transforms`` is documented as optional (default None), but the
        # original code called it unconditionally and crashed with TypeError when
        # no transform was supplied
        if self.transforms is not None:
            image, label = self.transforms(image, label)
        # remap label: ids not present in ``id_to_train_id`` become ``ignore_label``
        if isinstance(label, torch.Tensor):
            label = label.numpy()
        label = np.asarray(label, np.int64)
        label_copy = self.ignore_label * np.ones(label.shape, dtype=np.int64)
        if self.id_to_train_id:
            for k, v in self.id_to_train_id.items():
                label_copy[label == k] = v
        return image, label_copy.copy()

    @property
    def num_classes(self) -> int:
        """Number of classes"""
        return len(self.classes)

    def decode_target(self, target):
        """ Decode label (each value is integer) into the corresponding RGB value.

        Args:
            target (numpy.array): label in shape H x W

        Returns:
            RGB label (PIL Image) in shape H x W x 3
        """
        target = target.copy()
        # map ignored pixels to the extra trailing color entry (black)
        target[target == 255] = self.num_classes  # unknown label is black on the RGB label
        target = self.train_id_to_color[target]
        return Image.fromarray(target.astype(np.uint8))

    def collect_image_paths(self):
        """Return a list of the absolute path of all the images"""
        return [os.path.join(self.root, self.data_folder, image_name) for image_name in self.data_list]

    @staticmethod
    def _save_pil_image(image, path):
        # create intermediate directories on demand before saving
        os.makedirs(os.path.dirname(path), exist_ok=True)
        image.save(path)

    def translate(self, transform: Callable, target_root: str, color=False):
        """ Translate an image and save it into a specified directory

        Args:
            transform (callable): a transform function that maps (image, label) pair from one domain to another domain
            target_root (str): the root directory to save images and labels
            color (bool, optional): whether to additionally save a color-encoded copy of each
                translated label (suffixed ``_color``) for visualization. Default: False
        """
        os.makedirs(target_root, exist_ok=True)
        for image_name, label_name in zip(tqdm.tqdm(self.data_list), self.label_list):
            image_path = os.path.join(target_root, self.data_folder, image_name)
            label_path = os.path.join(target_root, self.label_folder, label_name)
            # skip pairs that were already translated in a previous run
            if os.path.exists(image_path) and os.path.exists(label_path):
                continue
            image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB')
            label = Image.open(os.path.join(self.root, self.label_folder, label_name))
            translated_image, translated_label = transform(image, label)
            self._save_pil_image(translated_image, image_path)
            self._save_pil_image(translated_label, label_path)
            if color:
                colored_label = self.decode_target(np.array(translated_label))
                file_name, file_ext = os.path.splitext(label_name)
                self._save_pil_image(colored_label, os.path.join(target_root, self.label_folder,
                                                                 "{}_color{}".format(file_name, file_ext)))

    @property
    def evaluate_classes(self):
        """The name of classes to be evaluated"""
        return self.classes

    @property
    def ignore_classes(self):
        """The name of classes to be ignored"""
        return list(set(self.classes) - set(self.evaluate_classes))
================================================
FILE: tllib/vision/datasets/segmentation/synthia.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from .segmentation_list import SegmentationList
from .cityscapes import Cityscapes
from .._util import download as download_data
class Synthia(SegmentationList):
    """SYNTHIA, a synthetic semantic segmentation dataset whose labels are mapped
    to the Cityscapes class set.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``.
        data_folder (str, optional): Sub-directory of the image. Default: 'RGB'.
        label_folder (str, optional): Sub-directory of the label. Default: 'synthia_mapped_to_cityscapes'.
        mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None.
        transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \
            and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`.

    .. note:: You need to download SYNTHIA manually.
        Ensure that there exist following directories in the `root` directory before you using this class.
        ::
            RGB/
            synthia_mapped_to_cityscapes/
    """
    # SYNTHIA raw label ids -> Cityscapes train ids
    ID_TO_TRAIN_ID = {
        3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5,
        15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12,
        8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18
    }
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1c652d518e0347e2800d/?dl=1"),
    ]

    def __init__(self, root, split='train', data_folder='RGB', label_folder='synthia_mapped_to_cityscapes', **kwargs):
        assert split in ['train']
        # fetch the meta information (image lists) from the Internet when absent
        for entry in self.download_list:
            download_data(root, *entry)
        split_list_file = os.path.join(root, "image_list", "{}.txt".format(split))
        super(Synthia, self).__init__(root, Cityscapes.CLASSES, split_list_file, split_list_file, data_folder,
                                      label_folder, id_to_train_id=Synthia.ID_TO_TRAIN_ID,
                                      train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs)

    @property
    def evaluate_classes(self):
        """The names of the 13 classes evaluated on this dataset."""
        return [
            'road', 'sidewalk', 'building', 'traffic light', 'traffic sign',
            'vegetation', 'sky', 'person', 'rider', 'car', 'bus', 'motorcycle', 'bicycle'
        ]
================================================
FILE: tllib/vision/datasets/stanford_cars.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class StanfordCars(ImageList):
    """The Stanford Cars dataset \
    contains 16,185 images of 196 classes of cars. \
    Each category has been split roughly in a 50-50 split. \
    There are 8,144 images for training and 8,041 images for testing.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/aeeb690e9886442aa267/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/fd80c30c120a42a08fd3/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/01e6b279f20440cb8bf9/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    # class names are simply the strings '1' .. '196'; generated instead of a
    # 196-entry hard-coded literal (value is identical)
    CLASSES = [str(i) for i in range(1, 197)]

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
                 **kwargs):
        if split == 'train':
            # per-category sampled training lists: train_100.txt, train_50.txt, ...
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # bugfix: map() over a single iterable passes each
            # (folder, archive, url) tuple as ONE argument, so the previous
            # two-parameter lambda raised TypeError whenever download=False;
            # unpack the folder name explicitly instead
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(StanfordCars, self).__init__(root, StanfordCars.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/stanford_dogs.py
================================================
"""
@author: Yifei Ji
@contact: jiyf990330@163.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class StanfordDogs(ImageList):
    """The Stanford Dogs dataset \
    contains 20,580 images of 120 breeds of dogs from around the world. \
    Each category is composed of exactly 100 training examples and around 72 testing examples.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7685b13c549a4591b429/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/9f19a6d1b14b4f1e8d13/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/a497b21e31cc4bfc9d45/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    # WordNet-style breed names; order defines the integer class labels
    CLASSES = ['n02085620-Chihuahua', 'n02085782-Japanese_spaniel', 'n02085936-Maltese_dog', 'n02086079-Pekinese',
               'n02086240-Shih-Tzu',
               'n02086646-Blenheim_spaniel', 'n02086910-papillon', 'n02087046-toy_terrier',
               'n02087394-Rhodesian_ridgeback',
               'n02088094-Afghan_hound', 'n02088238-basset', 'n02088364-beagle', 'n02088466-bloodhound',
               'n02088632-bluetick', 'n02089078-black-and-tan_coonhound',
               'n02089867-Walker_hound', 'n02089973-English_foxhound', 'n02090379-redbone', 'n02090622-borzoi',
               'n02090721-Irish_wolfhound', 'n02091032-Italian_greyhound',
               'n02091134-whippet', 'n02091244-Ibizan_hound', 'n02091467-Norwegian_elkhound', 'n02091635-otterhound',
               'n02091831-Saluki', 'n02092002-Scottish_deerhound',
               'n02092339-Weimaraner', 'n02093256-Staffordshire_bullterrier',
               'n02093428-American_Staffordshire_terrier', 'n02093647-Bedlington_terrier', 'n02093754-Border_terrier',
               'n02093859-Kerry_blue_terrier', 'n02093991-Irish_terrier', 'n02094114-Norfolk_terrier',
               'n02094258-Norwich_terrier', 'n02094433-Yorkshire_terrier',
               'n02095314-wire-haired_fox_terrier', 'n02095570-Lakeland_terrier', 'n02095889-Sealyham_terrier',
               'n02096051-Airedale', 'n02096177-cairn', 'n02096294-Australian_terrier',
               'n02096437-Dandie_Dinmont', 'n02096585-Boston_bull', 'n02097047-miniature_schnauzer',
               'n02097130-giant_schnauzer', 'n02097209-standard_schnauzer',
               'n02097298-Scotch_terrier', 'n02097474-Tibetan_terrier', 'n02097658-silky_terrier',
               'n02098105-soft-coated_wheaten_terrier', 'n02098286-West_Highland_white_terrier',
               'n02098413-Lhasa', 'n02099267-flat-coated_retriever', 'n02099429-curly-coated_retriever',
               'n02099601-golden_retriever', 'n02099712-Labrador_retriever',
               'n02099849-Chesapeake_Bay_retriever', 'n02100236-German_short-haired_pointer', 'n02100583-vizsla',
               'n02100735-English_setter', 'n02100877-Irish_setter',
               'n02101006-Gordon_setter', 'n02101388-Brittany_spaniel', 'n02101556-clumber',
               'n02102040-English_springer', 'n02102177-Welsh_springer_spaniel', 'n02102318-cocker_spaniel',
               'n02102480-Sussex_spaniel', 'n02102973-Irish_water_spaniel', 'n02104029-kuvasz', 'n02104365-schipperke',
               'n02105056-groenendael', 'n02105162-malinois', 'n02105251-briard', 'n02105412-kelpie',
               'n02105505-komondor', 'n02105641-Old_English_sheepdog', 'n02105855-Shetland_sheepdog',
               'n02106030-collie', 'n02106166-Border_collie', 'n02106382-Bouvier_des_Flandres', 'n02106550-Rottweiler',
               'n02106662-German_shepherd', 'n02107142-Doberman', 'n02107312-miniature_pinscher',
               'n02107574-Greater_Swiss_Mountain_dog',
               'n02107683-Bernese_mountain_dog', 'n02107908-Appenzeller', 'n02108000-EntleBucher', 'n02108089-boxer',
               'n02108422-bull_mastiff', 'n02108551-Tibetan_mastiff',
               'n02108915-French_bulldog', 'n02109047-Great_Dane', 'n02109525-Saint_Bernard', 'n02109961-Eskimo_dog',
               'n02110063-malamute', 'n02110185-Siberian_husky',
               'n02110627-affenpinscher', 'n02110806-basenji', 'n02110958-pug', 'n02111129-Leonberg',
               'n02111277-Newfoundland', 'n02111500-Great_Pyrenees', 'n02111889-Samoyed', 'n02112018-Pomeranian',
               'n02112137-chow', 'n02112350-keeshond', 'n02112706-Brabancon_griffon', 'n02113023-Pembroke',
               'n02113186-Cardigan',
               'n02113624-toy_poodle', 'n02113712-miniature_poodle', 'n02113799-standard_poodle',
               'n02113978-Mexican_hairless', 'n02115641-dingo', 'n02115913-dhole', 'n02116738-African_hunting_dog']

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False,
                 **kwargs):
        if split == 'train':
            # per-category sampled training lists: train_100.txt, train_50.txt, ...
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # bugfix: map() over a single iterable passes each
            # (folder, archive, url) tuple as ONE argument, so the previous
            # two-parameter lambda raised TypeError whenever download=False;
            # unpack the folder name explicitly instead
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(StanfordDogs, self).__init__(root, StanfordDogs.CLASSES, data_list_file=data_list_file, **kwargs)
================================================
FILE: tllib/vision/datasets/sun397.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class SUN397(ImageList):
    """SUN397 is a dataset for scene understanding
    with 108,754 images in 397 scene categories. The number of images varies across categories,
    but there are at least 100 images per category. Note that the authors construct 10 partitions,
    where each partition contains 50 training images and 50 testing images per class. We adopt partition 1.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.
    """
    dataset_url = ("SUN397", "SUN397.tar.gz", "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz")
    image_list_url = (
        "SUN397/image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/dec0775147c144ea9f75/?dl=1")

    def __init__(self, root, split='train', download=True, **kwargs):
        if download:
            # images come from the official mirror, the split files from a separate archive
            download_data(root, *self.dataset_url)
            download_data(os.path.join(root, 'SUN397'), *self.image_list_url)
        else:
            check_exits(root, "SUN397")
            check_exits(root, "SUN397/image_list")
        # idiom fix: the comprehension already builds a list; the extra list() wrapper was redundant
        classes = [str(i) for i in range(397)]
        root = os.path.join(root, 'SUN397')
        super(SUN397, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs)
================================================
FILE: tllib/vision/datasets/visda2017.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits
class VisDA2017(ImageList):
    """VisDA-2017 Dataset

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \
            ``'Real'``: real-world images.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
                aeroplance/
                    *.png
                ...
            validation/
            image_list/
                train.txt
                validation.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/c5f3ce59139144ec8221/?dl=1"),
        ("validation", "validation.tgz", "https://cloud.tsinghua.edu.cn/f/da70e4b1cf514ecea562/?dl=1")
    ]
    # domain name -> image list file; 'Synthetic' is the train split, 'Real' the validation split
    image_list = {
        "Synthetic": "image_list/train.txt",
        "Real": "image_list/validation.txt"
    }
    CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',
               'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])
        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # bugfix: map() over a single iterable passes each
            # (folder, archive, url) tuple as ONE argument, so the previous
            # two-parameter lambda raised TypeError whenever download=False;
            # unpack the folder name explicitly instead
            list(map(lambda args: check_exits(root, args[0]), self.download_list))
        super(VisDA2017, self).__init__(root, VisDA2017.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        """Return the names of all domains in this dataset."""
        return list(cls.image_list.keys())
================================================
FILE: tllib/vision/models/__init__.py
================================================
from .resnet import *
from .digits import *
# NOTE(review): ``__all__`` lists the submodule names rather than the symbols
# re-exported by the star imports above; confirm this is intentional before
# relying on ``from tllib.vision.models import *``.
__all__ = ['resnet', 'digits']
================================================
FILE: tllib/vision/models/digits.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
class LeNet(nn.Sequential):
    """LeNet-style feature extractor for 28 x 28 single-channel images.

    The sequential stack ends at the 500-dimensional fully-connected features;
    the classification head is created separately via :meth:`copy_head`.

    Args:
        num_classes (int): number of classes for the detached head. Default: 10
    """

    def __init__(self, num_classes=10):
        feature_layers = [
            nn.Conv2d(1, 20, kernel_size=5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(20, 50, kernel_size=5),
            nn.Dropout2d(p=0.5),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(50 * 4 * 4, 500),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        ]
        super(LeNet, self).__init__(*feature_layers)
        # number of target classes used when building the classifier head
        self.num_classes = num_classes
        # dimension of the features produced by this module
        self.out_features = 500

    def copy_head(self):
        """Return a fresh (untrained) linear classification head."""
        return nn.Linear(500, self.num_classes)
class DTN(nn.Sequential):
    """DTN feature extractor for 32 x 32 RGB images.

    The sequential stack ends at the 512-dimensional fully-connected features;
    the classification head is created separately via :meth:`copy_head`.

    Args:
        num_classes (int): number of classes for the detached head. Default: 10
    """

    def __init__(self, num_classes=10):
        feature_layers = [
            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(64),
            nn.Dropout2d(0.1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(128),
            nn.Dropout2d(0.3),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(256),
            nn.Dropout2d(0.5),
            nn.ReLU(),
            nn.Flatten(start_dim=1),
            nn.Linear(256 * 4 * 4, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(),
        ]
        super(DTN, self).__init__(*feature_layers)
        # number of target classes used when building the classifier head
        self.num_classes = num_classes
        # dimension of the features produced by this module
        self.out_features = 512

    def copy_head(self):
        """Return a fresh (untrained) linear classification head."""
        return nn.Linear(512, self.num_classes)
def lenet(pretrained=False, **kwargs):
    """LeNet model from
    "Gradient-based learning applied to document recognition"

    Args:
        num_classes (int): number of classes. Default: 10

    .. note::
        The input image size must be 28 x 28.
    """
    # ``pretrained`` is accepted for API symmetry with the other model
    # factories, but no pretrained weights are loaded here.
    model = LeNet(**kwargs)
    return model
def dtn(pretrained=False, **kwargs):
    """ DTN model

    Args:
        num_classes (int): number of classes. Default: 10

    .. note::
        The input image size must be 32 x 32.
    """
    # ``pretrained`` is accepted for API symmetry with the other model
    # factories, but no pretrained weights are loaded here.
    model = DTN(**kwargs)
    return model
================================================
FILE: tllib/vision/models/keypoint_detection/__init__.py
================================================
from .pose_resnet import *
from . import loss
__all__ = ['pose_resnet']
================================================
FILE: tllib/vision/models/keypoint_detection/loss.py
================================================
"""
Modified from https://github.com/microsoft/human-pose-estimation.pytorch
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
import torch.nn.functional as F
class JointsMSELoss(nn.Module):
    """
    Typical MSE loss for keypoint detection.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output. Default: ``'mean'``

    Inputs:
        - output (tensor): heatmap predictions
        - target (tensor): heatmap labels
        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.

    Shape:
        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - target: :math:`(minibatch, K, H, W)`.
        - target_weight: :math:`(minibatch, K)`.
        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'none'``.
    """

    def __init__(self, reduction='mean'):
        super(JointsMSELoss, self).__init__()
        # robustness fix: an unsupported reduction previously made ``forward``
        # silently return None; fail fast at construction time instead
        if reduction not in ('mean', 'none'):
            raise ValueError("reduction should be 'mean' or 'none', got {!r}".format(reduction))
        self.criterion = nn.MSELoss(reduction='none')
        self.reduction = reduction

    def forward(self, output, target, target_weight=None):
        B, K, _, _ = output.shape
        # flatten each heatmap so the loss is computed per keypoint over pixels
        heatmaps_pred = output.reshape((B, K, -1))
        heatmaps_gt = target.reshape((B, K, -1))
        loss = self.criterion(heatmaps_pred, heatmaps_gt) * 0.5
        if target_weight is not None:
            # zero out the loss of invisible keypoints
            loss = loss * target_weight.view((B, K, 1))
        if self.reduction == 'mean':
            return loss.mean()
        # reduction == 'none': per-keypoint loss, averaged over heatmap pixels
        return loss.mean(dim=-1)
class JointsKLLoss(nn.Module):
    """
    KL Divergence for keypoint detection proposed by
    `Regressive Domain Adaptation for Unsupervised Keypoint Detection`.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output. Default: ``'mean'``
        epsilon (float, optional): small constant added to the target heatmaps
            before normalization, avoiding division by zero on all-zero
            heatmaps. Default: 0.

    Inputs:
        - output (tensor): heatmap predictions
        - target (tensor): heatmap labels
        - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None.

    Shape:
        - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints,
          H and W is the height and width of the heatmap respectively.
        - target: :math:`(minibatch, K, H, W)`.
        - target_weight: :math:`(minibatch, K)`.
        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then
          :math:`(minibatch,)` (the per-sample loss, averaged over keypoints).

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'none'``.
    """

    def __init__(self, reduction='mean', epsilon=0.):
        super(JointsKLLoss, self).__init__()
        # robustness fix: an unsupported reduction previously made ``forward``
        # silently return None; fail fast at construction time instead
        if reduction not in ('mean', 'none'):
            raise ValueError("reduction should be 'mean' or 'none', got {!r}".format(reduction))
        self.criterion = nn.KLDivLoss(reduction='none')
        self.reduction = reduction
        self.epsilon = epsilon

    def forward(self, output, target, target_weight=None):
        B, K, _, _ = output.shape
        heatmaps_pred = output.reshape((B, K, -1))
        # predictions become log-probabilities over the heatmap pixels
        heatmaps_pred = F.log_softmax(heatmaps_pred, dim=-1)
        heatmaps_gt = target.reshape((B, K, -1))
        heatmaps_gt = heatmaps_gt + self.epsilon
        # normalize targets into a probability distribution per keypoint
        # (``keepdim`` is the canonical torch spelling of the numpy-style
        # ``keepdims`` alias used before)
        heatmaps_gt = heatmaps_gt / heatmaps_gt.sum(dim=-1, keepdim=True)
        loss = self.criterion(heatmaps_pred, heatmaps_gt).sum(dim=-1)
        if target_weight is not None:
            # zero out the loss of invisible keypoints
            loss = loss * target_weight.view((B, K))
        if self.reduction == 'mean':
            return loss.mean()
        # reduction == 'none': per-sample loss, averaged over keypoints
        return loss.mean(dim=-1)
================================================
FILE: tllib/vision/models/keypoint_detection/pose_resnet.py
================================================
"""
Modified from https://github.com/microsoft/human-pose-estimation.pytorch
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from ..resnet import _resnet, Bottleneck
class Upsampling(nn.Sequential):
    """
    3-layers deconvolution used in `Simple Baseline`_ for keypoint heatmap upsampling.

    Each stage is ConvTranspose2d(stride=2) + BatchNorm2d + ReLU, so every
    stage doubles the spatial resolution.
    """

    def __init__(self, in_channel=2048, hidden_dims=(256, 256, 256), kernel_sizes=(4, 4, 4), bias=False):
        assert len(hidden_dims) == len(kernel_sizes), \
            'ERROR: len(hidden_dims) is different len(kernel_sizes)'
        # supported deconv kernel sizes -> (padding, output_padding), each pair
        # chosen so the transposed convolution exactly doubles H and W
        geometry = {4: (1, 0), 3: (1, 1), 2: (0, 0)}
        stages = []
        current_channels = in_channel
        for out_channels, kernel_size in zip(hidden_dims, kernel_sizes):
            if kernel_size not in geometry:
                raise NotImplementedError("kernel_size is {}".format(kernel_size))
            padding, output_padding = geometry[kernel_size]
            stages.extend([
                nn.ConvTranspose2d(
                    in_channels=current_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=padding,
                    output_padding=output_padding,
                    bias=bias),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
            current_channels = out_channels
        super(Upsampling, self).__init__(*stages)
        # weight init following Simple Baseline
        for name, module in self.named_modules():
            if isinstance(module, nn.ConvTranspose2d):
                nn.init.normal_(module.weight, std=0.001)
                if bias:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
class PoseResNet(nn.Module):
    """
    `Simple Baseline `_ for keypoint detection.

    Args:
        backbone (torch.nn.Module): Backbone to extract 2-d features from data
        upsampling (torch.nn.Module): Layer to upsample image feature to heatmap size
        feature_dim (int): The dimension of the features from upsampling layer.
        num_keypoints (int): Number of keypoints
        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: False
    """
    def __init__(self, backbone, upsampling, feature_dim, num_keypoints, finetune=False):
        super(PoseResNet, self).__init__()
        self.finetune = finetune
        self.backbone = backbone
        self.upsampling = upsampling
        # 1x1 convolution projecting upsampled features to one heatmap per keypoint.
        self.head = nn.Conv2d(in_channels=feature_dim, out_channels=num_keypoints, kernel_size=1, stride=1, padding=0)
        for module in self.head.modules():
            nn.init.normal_(module.weight, std=0.001)
            nn.init.constant_(module.bias, 0)

    def forward(self, x):
        return self.head(self.upsampling(self.backbone(x)))

    def get_parameters(self, lr=1.):
        """Parameter groups with per-module learning rates (backbone may get 10x less)."""
        backbone_lr = 0.1 * lr if self.finetune else lr
        return [
            {'params': self.backbone.parameters(), 'lr': backbone_lr},
            {'params': self.upsampling.parameters(), 'lr': lr},
            {'params': self.head.parameters(), 'lr': lr},
        ]
def _pose_resnet(arch, num_keypoints, block, layers, pretrained_backbone, deconv_with_bias, finetune=False, progress=True, **kwargs):
    """Assemble a ``PoseResNet``: ResNet backbone + 3-layer upsampling + 1x1 prediction head."""
    backbone = _resnet(arch, block, layers, pretrained_backbone, progress, **kwargs)
    upsampling = Upsampling(backbone.out_features, bias=deconv_with_bias)
    # The upsampling stack ends with 256 channels, which feed the prediction head.
    return PoseResNet(backbone, upsampling, 256, num_keypoints, finetune)
def pose_resnet101(num_keypoints, pretrained_backbone=True, deconv_with_bias=False, finetune=False, progress=True, **kwargs):
    """Constructs a Simple Baseline model with a ResNet-101 backbone.

    Args:
        num_keypoints (int): number of keypoints
        pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. Default: True.
        deconv_with_bias (bool, optional): Whether use bias in the deconvolution layer. Default: False
        finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: False
        progress (bool, optional): If True, displays a progress bar of the download to stderr. Default: True
    """
    # ResNet-101 uses Bottleneck blocks with a (3, 4, 23, 3) stage layout.
    return _pose_resnet('resnet101', num_keypoints, Bottleneck, [3, 4, 23, 3],
                        pretrained_backbone, deconv_with_bias, finetune, progress, **kwargs)
================================================
FILE: tllib/vision/models/object_detection/__init__.py
================================================
from . import meta_arch
from . import roi_heads
from . import proposal_generator
from . import backbone
================================================
FILE: tllib/vision/models/object_detection/backbone/__init__.py
================================================
from .vgg import VGG, build_vgg_fpn_backbone
================================================
FILE: tllib/vision/models/object_detection/backbone/mmdetection/vgg.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
# Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/vgg.py
from mmcv.runner import load_checkpoint
import torch.nn as nn
from .weight_init import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1):
    """Return a 3x3 convolution whose padding equals its dilation,
    so the spatial resolution is preserved."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3,
                     dilation=dilation, padding=dilation)
def make_vgg_layer(inplanes,
                   planes,
                   num_blocks,
                   dilation=1,
                   with_bn=False,
                   ceil_mode=False):
    """Build one VGG stage: ``num_blocks`` conv(+BN)+ReLU units followed by a
    2x2 stride-2 max-pooling layer. Returns a flat list of modules."""
    stage = []
    for block_idx in range(num_blocks):
        # Only the first conv of the stage changes the channel count.
        in_ch = inplanes if block_idx == 0 else planes
        conv = conv3x3(in_ch, planes, dilation)
        if with_bn:
            stage.extend([conv, nn.BatchNorm2d(planes), nn.ReLU(inplace=True)])
        else:
            stage.extend([conv, nn.ReLU(inplace=True)])
    stage.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
    return stage
class VGG(nn.Module):
    """VGG backbone.

    Args:
        depth (int): Depth of vgg, from {11, 13, 16, 19}.
        with_bn (bool): Use BatchNorm or not.
        num_classes (int): number of classes for classification.
        num_stages (int): VGG stages, normally 5.
        dilations (Sequence[int]): Dilation of each stage.
        out_indices (Sequence[int]): Output from which stages.
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters.
        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
            running stats (mean and var).
        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
        ceil_mode (bool): ``ceil_mode`` passed to the max-pooling layers.
        with_last_pool (bool): Whether to keep the final max-pooling layer.
    """

    # Number of conv blocks in each of the five stages, keyed by network depth.
    arch_settings = {
        11: (1, 1, 2, 2, 2),
        13: (2, 2, 2, 2, 2),
        16: (2, 2, 3, 3, 3),
        19: (2, 2, 4, 4, 4)
    }

    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for vgg'.format(depth))
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        # Keep only the first `num_stages` stages.
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages
        assert max(out_indices) <= num_stages
        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen
        self.inplanes = 3
        start_idx = 0
        vgg_layers = []
        # [start, end) index range of each stage inside the flat `features` Sequential.
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            # Each block contributes conv+relu (+BN), plus one pooling layer per stage.
            num_modules = num_blocks * (2 + with_bn) + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            # Channel width doubles per stage (64, 128, 256, 512), capped at 512.
            planes = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.inplanes,
                planes,
                num_blocks,
                dilation=dilation,
                with_bn=with_bn,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.inplanes = planes
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            # Drop the final max-pool and shrink the last stage's index range.
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))
        if self.num_classes > 0:
            # Standard VGG classification head (assumes 7x7 spatial input).
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )
        # initialize the model by random
        self.init_weights()
        # Optionally freeze (requires_grad=False) parts of the backbone
        self._freeze_backbone(self.frozen_stages)

    def _freeze_backbone(self, freeze_at):
        """Freeze the first ``freeze_at`` stages: eval mode + requires_grad=False."""
        if freeze_at < 0:
            return
        vgg_layers = getattr(self, self.module_name)
        for i in range(freeze_at):
            for j in range(*self.range_sub_modules[i]):
                mod = vgg_layers[j]
                mod.eval()
                for param in mod.parameters():
                    param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Load weights from a checkpoint path, or randomly initialize if None."""
        if isinstance(pretrained, str):
            load_checkpoint(self, pretrained, strict=False)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, nn.BatchNorm2d):
                    constant_init(m, 1)
                elif isinstance(m, nn.Linear):
                    normal_init(m, std=0.01)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Return the features of the stages in ``out_indices`` (plus the
        classifier output when ``num_classes > 0``); a single tensor when only
        one output is produced, otherwise a tuple."""
        outs = []
        vgg_layers = getattr(self, self.module_name)
        for i, num_blocks in enumerate(self.stage_blocks):
            for j in range(*self.range_sub_modules[i]):
                vgg_layer = vgg_layers[j]
                x = vgg_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), -1)
            x = self.classifier(x)
            outs.append(x)
        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)
================================================
FILE: tllib/vision/models/object_detection/backbone/mmdetection/weight_init.py
================================================
# Copyright (c) Open-MMLab. All rights reserved.
# Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/weight_init.py
import numpy as np
import torch.nn as nn
def constant_init(module, val, bias=0):
    """Fill ``module.weight`` with ``val`` and, when present, ``module.bias`` with ``bias``."""
    nn.init.constant_(module.weight, val)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
    """Xavier-initialize ``module.weight`` (uniform or normal variant) and
    constant-fill the bias when present."""
    assert distribution in ['uniform', 'normal']
    init_fn = nn.init.xavier_uniform_ if distribution == 'uniform' else nn.init.xavier_normal_
    init_fn(module.weight, gain=gain)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
    """Draw ``module.weight`` from N(mean, std) and constant-fill the bias when present."""
    nn.init.normal_(module.weight, mean, std)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
    """Draw ``module.weight`` from U(a, b) and constant-fill the bias when present."""
    nn.init.uniform_(module.weight, a, b)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def kaiming_init(module,
                 a=0,
                 mode='fan_out',
                 nonlinearity='relu',
                 bias=0,
                 distribution='normal'):
    """Kaiming-initialize ``module.weight`` (uniform or normal variant) and
    constant-fill the bias when present."""
    assert distribution in ['uniform', 'normal']
    kaiming_fn = (nn.init.kaiming_uniform_ if distribution == 'uniform'
                  else nn.init.kaiming_normal_)
    kaiming_fn(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if getattr(module, 'bias', None) is not None:
        nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
    """Caffe2-style Xavier initialization.

    `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
    (acknowledgment to FAIR's internal code).

    Args:
        module: module with a ``weight`` tensor (and optional ``bias``).
        bias (float): value for the bias fill. Previously this argument was
            silently ignored; it is now forwarded to ``kaiming_init``
            (matching upstream mmcv), which keeps the default behavior
            (bias filled with 0) unchanged.
    """
    kaiming_init(
        module,
        a=1,
        mode='fan_in',
        nonlinearity='leaky_relu',
        bias=bias,
        distribution='uniform')
================================================
FILE: tllib/vision/models/object_detection/backbone/vgg.py
================================================
# Reference: https://github.com/chengchunhsu/EveryPixelMatters
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.modeling.backbone import Backbone, BACKBONE_REGISTRY
from .mmdetection.vgg import VGG
class FPN(nn.Module):
    """
    Module that adds FPN on top of a list of feature maps.
    The feature maps are currently supposed to be in increasing depth
    order, and must be consecutive
    """

    def __init__(
        self, in_channels_list, out_channels, conv_block, top_blocks=None
    ):
        """
        Arguments:
            in_channels_list (list[int]): number of channels for each feature map that
                will be fed; a 0 entry means that level is skipped
            out_channels (int): number of channels of the FPN representation
            conv_block (callable): factory ``(in_ch, out_ch, kernel[, stride])``
                producing a conv layer
            top_blocks (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list
        """
        super(FPN, self).__init__()
        self.inner_blocks = []
        self.layer_blocks = []
        for idx, in_channels in enumerate(in_channels_list, 1):
            inner_block = "fpn_inner{}".format(idx)
            layer_block = "fpn_layer{}".format(idx)
            if in_channels == 0:
                # Levels with zero channels do not get lateral/output convs.
                continue
            inner_block_module = conv_block(in_channels, out_channels, 1)
            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
            self.add_module(inner_block, inner_block_module)
            self.add_module(layer_block, layer_block_module)
            self.inner_blocks.append(inner_block)
            self.layer_blocks.append(layer_block)
        self.top_blocks = top_blocks

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): feature maps for each feature level.
        Returns:
            results (tuple[Tensor]): feature maps after FPN layers.
            They are ordered from highest resolution first.
        """
        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
        results = []
        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
        # Top-down pathway: iterate from the deepest remaining level upwards.
        for feature, inner_block, layer_block in zip(
            x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            if not inner_block:
                continue
            inner_lateral = getattr(self, inner_block)(feature)
            # Upsample to the lateral feature's exact size (not a fixed scale
            # factor) so odd input sizes are handled correctly.
            # F.upsample is deprecated; F.interpolate is its supported equivalent.
            inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:],
                                           mode='bilinear', align_corners=False)
            last_inner = inner_lateral + inner_top_down
            results.insert(0, getattr(self, layer_block)(last_inner))
        if isinstance(self.top_blocks, LastLevelP6P7):
            last_results = self.top_blocks(x[-1], results[-1])
            results.extend(last_results)
        elif isinstance(self.top_blocks, LastLevelMaxPool):
            last_results = self.top_blocks(results[-1])
            results.extend(last_results)
        return tuple(results)
class LastLevelMaxPool(nn.Module):
    """Produce one extra pyramid level by stride-2 subsampling of the input."""

    def forward(self, x):
        # kernel 1, stride 2, no padding: pure spatial subsampling.
        pooled = F.max_pool2d(x, 1, 2, 0)
        return [pooled]
class LastLevelP6P7(nn.Module):
    """
    This module is used in RetinaNet to generate extra layers, P6 and P7.
    """
    def __init__(self, in_channels, out_channels):
        super(LastLevelP6P7, self).__init__()
        # When channel counts match, P6 is computed from P5 instead of C5.
        self.use_P5 = in_channels == out_channels
        self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1)
        self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1)
        for conv in (self.p6, self.p7):
            nn.init.kaiming_uniform_(conv.weight, a=1)
            nn.init.constant_(conv.bias, 0)

    def forward(self, c5, p5):
        source = p5 if self.use_P5 else c5
        p6 = self.p6(source)
        return [p6, self.p7(F.relu(p6))]
class _NewEmptyTensorOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, new_shape):
ctx.shape = x.shape
return x.new_empty(new_shape)
@staticmethod
def backward(ctx, grad):
shape = ctx.shape
return _NewEmptyTensorOp.apply(grad, shape), None
class Conv2d(torch.nn.Conv2d):
    """Conv2d that tolerates empty (0-element) inputs by computing the output
    shape analytically instead of running the convolution kernel."""

    def forward(self, x):
        if x.numel() == 0:
            # Standard convolution output-size arithmetic, per spatial dim.
            spatial = []
            for size, pad, dil, k, stride in zip(
                x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
            ):
                spatial.append((size + 2 * pad - (dil * (k - 1) + 1)) // stride + 1)
            out_shape = [x.shape[0], self.weight.shape[0]] + spatial
            return _NewEmptyTensorOp.apply(x, out_shape)
        return super(Conv2d, self).forward(x)
def conv_with_kaiming_uniform():
    """Return a factory building a ``Conv2d`` with Caffe2-style XavierFill init.

    Returns:
        callable: ``make_conv(in_channels, out_channels, kernel_size, stride=1, dilation=1)``
        producing a bias-enabled ``Conv2d`` with "same" padding for odd kernels.
    """
    def make_conv(
        in_channels, out_channels, kernel_size, stride=1, dilation=1
    ):
        conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            # "same" padding for odd kernel sizes
            padding=dilation * (kernel_size - 1) // 2,
            dilation=dilation,
            bias=True
        )
        # Caffe2 implementation uses XavierFill, which in fact
        # corresponds to kaiming_uniform_ in PyTorch
        nn.init.kaiming_uniform_(conv.weight, a=1)
        nn.init.constant_(conv.bias, 0)
        # NOTE: the original wrapped `conv` in a single-element list with an
        # unreachable nn.Sequential branch; that dead code is removed.
        return conv
    return make_conv
class VGGFPN(Backbone):
    """VGG body + FPN neck packaged as a detectron2 ``Backbone`` producing
    pyramid levels p3-p7 (256 channels each, strides 8-128)."""

    def __init__(self, body, fpn):
        super(VGGFPN, self).__init__()
        self.body = body
        self.fpn = fpn
        levels = ["p3", "p4", "p5", "p6", "p7"]
        self._out_features = levels
        # Every pyramid level has 256 channels.
        self._out_feature_channels = {name: 256 for name in levels}
        self._out_feature_strides = dict(zip(levels, [8, 16, 32, 64, 128]))

    def forward(self, x):
        features = self.fpn(self.body(x))
        return dict(zip(self._out_features, features))
@BACKBONE_REGISTRY.register()
def build_vgg_fpn_backbone(cfg, input_shape):
    """Build a VGG-16 + FPN backbone (levels p3-p7) for detectron2.

    Args:
        cfg: detectron2 config; ``cfg.MODEL.WEIGHTS`` initializes the VGG body.
        input_shape: unused, kept for the detectron2 backbone-builder interface.
    """
    body = VGG(depth=16, with_last_pool=True, frozen_stages=2)
    body.init_weights(cfg.MODEL.WEIGHTS)
    out_channels = 256  # default: cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS (256)
    base_channels = 128  # default: cfg.MODEL.RESNETS.RES2_OUT_CHANNELS (256)
    # Stages 1-2 are skipped (0 channels); stages 3-5 feed the FPN laterals.
    lateral_channels = [
        0,
        0,
        base_channels * 2,
        base_channels * 4,
        base_channels * 4,  # base_channels * 8
    ]
    fpn = FPN(
        in_channels_list=lateral_channels,
        out_channels=out_channels,
        conv_block=conv_with_kaiming_uniform(),
        top_blocks=LastLevelP6P7(out_channels, out_channels),
    )
    model = VGGFPN(body, fpn)
    model.out_channels = out_channels
    return model
================================================
FILE: tllib/vision/models/object_detection/meta_arch/__init__.py
================================================
from .rcnn import TLGeneralizedRCNN
from .retinanet import TLRetinaNet
================================================
FILE: tllib/vision/models/object_detection/meta_arch/rcnn.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Tuple, Dict
import torch
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN as GeneralizedRCNNBase, get_event_storage
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class TLGeneralizedRCNN(GeneralizedRCNNBase):
    """
    Generalized R-CNN for Transfer Learning.

    As in Supervised Learning, TLGeneralizedRCNN has three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    Different from Supervised Learning, TLGeneralizedRCNN

    1. accepts unlabeled images during training (returns no losses for them)
    2. returns detection outputs, features, and losses during training

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        proposal_generator: a module that generates proposals using backbone features
        roi_heads: a ROI head that performs per-region computation
        pixel_mean, pixel_std: list or tuple with #channels element,
            representing the per-channel mean and std to be used to normalize
            the input image
        input_format: describe the meaning of channels of input. Needed by visualization
        vis_period: the period to run visualization. Set to 0 to disable.
        finetune (bool): whether finetune the detector or train from scratch. Default: True

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * proposals (optional): :class:`Instances`, precomputed proposals.
          * "height", "width" (int): the output resolution of the model, used in inference.
            See :meth:`postprocess` for details.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - outputs: a dict with a key "features" (features of middle layers) and
          the per-image detection outputs from the ROI heads.
        - losses: a dict of different losses
    """

    def __init__(self, *args, finetune=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetune = finetune

    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):
        """"""
        if not self.training:
            return self.inference(batched_inputs)
        images = self.preprocess_image(batched_inputs)
        # Ground truth is only consumed when the batch is labeled.
        gt_instances = None
        if labeled and "instances" in batched_inputs[0]:
            gt_instances = [item["instances"].to(self.device) for item in batched_inputs]
        features = self.backbone(images.tensor)
        if self.proposal_generator is None:
            # Fall back to precomputed proposals supplied with the inputs.
            assert "proposals" in batched_inputs[0]
            proposals = [item["proposals"].to(self.device) for item in batched_inputs]
            proposal_losses = {}
        else:
            proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, labeled)
        outputs, detector_losses = self.roi_heads(images, features, proposals, gt_instances, labeled)
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)
        losses = dict(detector_losses)
        losses.update(proposal_losses)
        outputs['features'] = features
        return outputs, losses

    def get_parameters(self, lr=1.):
        """Return a parameter list which decides optimization hyper-parameters,
        such as the learning rate of each layer
        """
        backbone_lr = 0.1 * lr if self.finetune else lr
        return [
            (self.backbone, backbone_lr),
            (self.proposal_generator, lr),
            (self.roi_heads, lr),
        ]
================================================
FILE: tllib/vision/models/object_detection/meta_arch/retinanet.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Dict, List, Tuple
import torch
from torch import Tensor, nn
from detectron2.modeling.meta_arch.retinanet import RetinaNet as RetinaNetBase
from detectron2.modeling import detector_postprocess
class TLRetinaNet(RetinaNetBase):
    """
    RetinaNet for Transfer Learning.

    Different from Supervised Learning, TLRetinaNet

    1. accepts unlabeled images during training (returns no losses for them)
    2. returns detection outputs, features, and losses during training

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        head (nn.Module): a module that predicts logits and regression deltas
            for each level from a list of per-level features
        head_in_features (Tuple[str]): Names of the input feature maps to be used in head
        anchor_generator (nn.Module): a module that creates anchors from a
            list of features. Usually an instance of :class:`AnchorGenerator`
        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to
            instance boxes
        anchor_matcher (Matcher): label the anchors by matching them with ground truth.
        num_classes (int): number of classes. Used to label background proposals.
        focal_loss_alpha (float): focal_loss_alpha
        focal_loss_gamma (float): focal_loss_gamma
        smooth_l1_beta (float): smooth_l1_beta
        box_reg_loss_type (str): Options are "smooth_l1", "giou"
        test_score_thresh (float): Inference cls score threshold, only anchors with
            score > INFERENCE_TH are considered for inference (to improve speed)
        test_topk_candidates (int): Select topk candidates before NMS
        test_nms_thresh (float): Overlap threshold used for non-maximum suppression
            (suppress boxes with IoU >= this threshold)
        max_detections_per_image (int):
            Maximum number of detections to return per image during inference
            (100 is based on the limit established for the COCO dataset).
        pixel_mean (Tuple[float]):
            Values to be used for image normalization (BGR order).
            Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
        pixel_std (Tuple[float]):
            When using pre-trained models in Detectron1 or any MSRA models,
            std has been absorbed into its conv1 weights, so the std needs to be set 1.
            Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
        vis_period (int):
            The period (in terms of steps) for minibatch visualization at train time.
            Set to 0 to disable.
        input_format (str): Whether the model needs RGB, YUV, HSV etc.
        finetune (bool): whether finetune the detector or train from scratch. Default: True

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * "height", "width" (int): the output resolution of the model, used in inference.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - during training: ``(outputs, losses)`` where ``outputs`` holds the
          per-level head features under key "features"
        - during inference: a list of dicts with key "instances"
    """

    def __init__(self, *args, finetune=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetune = finetune

    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True):
        """"""
        images = self.preprocess_image(batched_inputs)
        feature_maps = self.backbone(images.tensor)
        features = [feature_maps[name] for name in self.head_in_features]
        predictions = self.head(features)
        if not self.training:
            results = self.forward_inference(images, features, predictions)
            if torch.jit.is_scripting():
                return results
            # Rescale each result to the requested output resolution.
            processed_results = []
            for per_image_results, per_image_input, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                out_h = per_image_input.get("height", image_size[0])
                out_w = per_image_input.get("width", image_size[1])
                r = detector_postprocess(per_image_results, out_h, out_w)
                processed_results.append({"instances": r})
            return processed_results
        # Training: compute losses only for labeled batches.
        if labeled:
            assert not torch.jit.is_scripting(), "Not supported"
            assert "instances" in batched_inputs[0], "Instance annotations are missing in training!"
            gt_instances = [item["instances"].to(self.device) for item in batched_inputs]
            losses = self.forward_training(images, features, predictions, gt_instances)
        else:
            losses = {}
        return {"features": features}, losses

    def get_parameters(self, lr=1.):
        """Return a parameter list which decides optimization hyper-parameters,
        such as the learning rate of each layer
        """
        backbone_lr = 0.1 * lr if self.finetune else lr
        return [
            (self.backbone.bottom_up, backbone_lr),
            (self.backbone.fpn_lateral4, lr),
            (self.backbone.fpn_output4, lr),
            (self.backbone.fpn_lateral5, lr),
            (self.backbone.fpn_output5, lr),
            (self.backbone.top_block, lr),
            (self.head, lr),
        ]
================================================
FILE: tllib/vision/models/object_detection/proposal_generator/__init__.py
================================================
from .rpn import TLRPN
================================================
FILE: tllib/vision/models/object_detection/proposal_generator/rpn.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Dict, Optional, List
import torch
from detectron2.structures import ImageList, Instances
from detectron2.modeling.proposal_generator import (
RPN,
PROPOSAL_GENERATOR_REGISTRY,
)
@PROPOSAL_GENERATOR_REGISTRY.register()
class TLRPN(RPN):
    """
    Region Proposal Network, introduced by `Faster R-CNN`, extended so that
    losses are only computed when the batch is labeled.

    Args:
        in_features (list[str]): list of names of input features to use
        head (nn.Module): a module that predicts logits and regression deltas
            for each level from a list of per-level features
        anchor_generator (nn.Module): a module that creates anchors from a
            list of features. Usually an instance of :class:`AnchorGenerator`
        anchor_matcher (Matcher): label the anchors by matching them with ground truth.
        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to
            instance boxes
        batch_size_per_image (int): number of anchors per image to sample for training
        positive_fraction (float): fraction of foreground anchors to sample for training
        pre_nms_topk (tuple[float]): (train, test) number of top k proposals to
            select before NMS
        post_nms_topk (tuple[float]): (train, test) number of top k proposals to
            select after NMS
        nms_thresh (float): NMS threshold used to de-duplicate the predicted proposals
        min_box_size (float): remove proposal boxes with any side smaller than this threshold,
            in the unit of input image pixels
        anchor_boundary_thresh (float): legacy option
        loss_weight (float|dict): weights to use for losses. Valid dict keys are:
            "loss_rpn_cls" (classification loss) and "loss_rpn_loc" (box regression loss)
        box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou".
        smooth_l1_beta (float): beta parameter for the smooth L1 regression loss.
            Only used when `box_reg_loss_type` is "smooth_l1"

    Inputs:
        - images (ImageList): input images of length `N`
        - features (dict[str, Tensor]): mapping from feature map name to tensor,
          axis 0 being the image index
        - gt_instances (list[Instances], optional): length `N` ground-truth instances
        - labeled (bool, optional): whether has ground-truth label. Default: True

    Outputs:
        - proposals: list[Instances] with fields "proposal_boxes", "objectness_logits"
        - loss: dict[Tensor], empty when unlabeled or not training
    """

    def __init__(self, *args, **kwargs):
        super(TLRPN, self).__init__(*args, **kwargs)

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        gt_instances: Optional[List[Instances]] = None,
        labeled: Optional[bool] = True
    ):
        feature_list = [features[name] for name in self.in_features]
        anchors = self.anchor_generator(feature_list)
        pred_objectness_logits, pred_anchor_deltas = self.rpn_head(feature_list)
        # Transpose the Hi*Wi*A dimension to the middle:
        # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
        pred_objectness_logits = [
            score.permute(0, 2, 3, 1).flatten(1)
            for score in pred_objectness_logits
        ]
        # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)
        box_dim = self.anchor_generator.box_dim
        pred_anchor_deltas = [
            delta.view(delta.shape[0], -1, box_dim, delta.shape[-2], delta.shape[-1])
            .permute(0, 3, 4, 1, 2)
            .flatten(1, -2)
            for delta in pred_anchor_deltas
        ]
        losses = {}
        if self.training and labeled:
            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
            losses = self.losses(
                anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes
            )
        proposals = self.predict_proposals(
            anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
        )
        return proposals, losses
================================================
FILE: tllib/vision/models/object_detection/roi_heads/__init__.py
================================================
from .roi_heads import TLRes5ROIHeads, TLStandardROIHeads
================================================
FILE: tllib/vision/models/object_detection/roi_heads/roi_heads.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
from typing import List, Dict
from detectron2.structures import Instances
from detectron2.modeling.roi_heads import (
ROI_HEADS_REGISTRY,
Res5ROIHeads,
StandardROIHeads,
select_foreground_proposals,
)
@ROI_HEADS_REGISTRY.register()
class TLRes5ROIHeads(Res5ROIHeads):
"""
The ROIHeads in a typical "C4" R-CNN model, where
the box and mask head share the cropping and
the per-region feature computation by a Res5 block.
Args:
in_features (list[str]): list of backbone feature map names to use for
feature extraction
pooler (ROIPooler): pooler to extra region features from backbone
res5 (nn.Sequential): a CNN to compute per-region features, to be used by
``box_predictor`` and ``mask_head``. Typically this is a "res5"
block from a ResNet.
box_predictor (nn.Module): make box predictions from the feature.
Should have the same interface as :class:`FastRCNNOutputLayers`.
mask_head (nn.Module): transform features to make mask predictions
Inputs:
- images (ImageList):
- features (dict[str,Tensor]): input data as a mapping from feature
map name to tensor. Axis 0 represents the number of images `N` in
the input data; axes 1-3 are channels, height, and width, which may
vary between feature maps (e.g., if a feature pyramid is used).
- proposals (list[Instances]): length `N` list of `Instances`. The i-th
`Instances` contains object proposals for the i-th input image,
with fields "proposal_boxes" and "objectness_logits".
- targets (list[Instances], optional): length `N` list of `Instances`. The i-th
`Instances` contains the ground-truth per-instance annotations
for the i-th input image. Specify `targets` during training only.
It may have the following fields:
- gt_boxes: the bounding box of each instance.
- gt_classes: the label for each instance with a category ranging in [0, #class].
- gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
- gt_keypoints: NxKx3, the ground-truth keypoints for each instance.
- labeled (bool, optional): whether has ground-truth label. Default: True
Outputs:
- list[Instances]: length `N` list of `Instances` containing the
detected instances. Returned during inference only; may be [] during training.
- dict[str->Tensor]:
mapping from a named loss to a tensor storing the loss. Used during training only.
"""
    def __init__(self, *args, **kwargs):
        # Pure pass-through: keeps the parent (detectron2 Res5ROIHeads)
        # initialization unchanged; this subclass only overrides forward logic.
        super(TLRes5ROIHeads, self).__init__(*args, **kwargs)
    def forward(self, images, features, proposals, targets=None, labeled=True):
        """"""
        # `images` is part of the base-class API but unused here: the heads
        # operate on the already-extracted feature maps.
        del images
        if self.training:
            if labeled:
                # match proposals to ground truth and subsample (detectron2 helper)
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                # no annotations: just keep the top proposals per image
                proposals = self.sample_unlabeled_proposals(proposals)
        del targets
        proposal_boxes = [x.proposal_boxes for x in proposals]
        # shared crop + res5 block: per-region features used by box AND mask heads
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        # spatial average pooling of res5 output before the box predictor
        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))
        if self.training:
            del features
            if labeled:
                losses = self.box_predictor.losses(predictions, proposals)
                if self.mask_on:
                    proposals, fg_selection_masks = select_foreground_proposals(
                        proposals, self.num_classes
                    )
                    # Since the ROI feature transform is shared between boxes and masks,
                    # we don't need to recompute features. The mask loss is only defined
                    # on foreground proposals, so we need to select out the foreground
                    # features.
                    mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
                    # del box_features
                    losses.update(self.mask_head(mask_features, proposals))
            else:
                # unlabeled mini-batch: no supervised loss can be computed
                losses = {}
            # expose raw predictions and region features so adaptation modules
            # outside the ROI heads can consume them
            outputs = {
                'predictions': predictions[0],
                'box_features': box_features
            }
            return outputs, losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}
@torch.no_grad()
def sample_unlabeled_proposals(
self, proposals: List[Instances]
) -> List[Instances]:
"""
Prepare some unlabeled proposals.
It returns top ``self.batch_size_per_image`` samples from proposals
Args:
proposals (list[Instances]): length `N` list of `Instances`. The i-th
`Instances` contains object proposals for the i-th input image,
with fields "proposal_boxes" and "objectness_logits".
Returns:
length `N` list of `Instances`s containing the proposals sampled for training.
"""
return [proposal[:self.batch_size_per_image] for proposal in proposals]
@ROI_HEADS_REGISTRY.register()
class TLStandardROIHeads(StandardROIHeads):
    """
    It's "standard" in a sense that there is no ROI transform sharing
    or feature sharing between tasks.
    Each head independently processes the input features by each head's
    own pooler and head.

    Args:
        box_in_features (list[str]): list of feature names to use for the box head.
        box_pooler (ROIPooler): pooler to extra region features for box head
        box_head (nn.Module): transform features to make box predictions
        box_predictor (nn.Module): make box predictions from the feature.
            Should have the same interface as :class:`FastRCNNOutputLayers`.
        mask_in_features (list[str]): list of feature names to use for the mask
            pooler or mask head. None if not using mask head.
        mask_pooler (ROIPooler): pooler to extract region features from image features.
            The mask head will then take region features to make predictions.
            If None, the mask head will directly take the dict of image features
            defined by `mask_in_features`
        mask_head (nn.Module): transform features to make mask predictions
        keypoint_in_features, keypoint_pooler, keypoint_head: similar to ``mask_*``.
        train_on_pred_boxes (bool): whether to use proposal boxes or
            predicted boxes from the box head to train other heads.

    Inputs:
        - images (ImageList):
        - features (dict[str,Tensor]): input data as a mapping from feature
          map name to tensor. Axis 0 represents the number of images `N` in
          the input data; axes 1-3 are channels, height, and width, which may
          vary between feature maps (e.g., if a feature pyramid is used).
        - proposals (list[Instances]): length `N` list of `Instances`. The i-th
          `Instances` contains object proposals for the i-th input image,
          with fields "proposal_boxes" and "objectness_logits".
        - targets (list[Instances], optional): length `N` list of `Instances`. The i-th
          `Instances` contains the ground-truth per-instance annotations
          for the i-th input image. Specify `targets` during training only.
          It may have the following fields:
            - gt_boxes: the bounding box of each instance.
            - gt_classes: the label for each instance with a category ranging in [0, #class].
            - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance.
            - gt_keypoints: NxKx3, the ground-truth keypoints for each instance.
        - labeled (bool, optional): whether has ground-truth label. Default: True

    Outputs:
        - list[Instances]: length `N` list of `Instances` containing the
          detected instances. Returned during inference only; may be [] during training.
        - dict[str->Tensor]:
          mapping from a named loss to a tensor storing the loss. Used during training only.
    """

    def __init__(self, *args, **kwargs):
        super(TLStandardROIHeads, self).__init__(*args, **kwargs)

    def forward(self, images, features, proposals, targets=None, labeled=True):
        """"""
        del images
        if self.training:
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                proposals = self.sample_unlabeled_proposals(proposals)
        del targets

        if self.training:
            if labeled:
                outputs, losses = self._forward_box(features, proposals)
                # Usually the original proposals used by the box head are used by the mask, keypoint
                # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
                # predicted by the box head.
                losses.update(self._forward_mask(features, proposals))
                losses.update(self._forward_keypoint(features, proposals))
            else:
                # BUGFIX: previously `outputs` was returned below without ever being
                # assigned on this path, raising NameError for unlabeled training
                # data. Mirror TLRes5ROIHeads: still expose box predictions and
                # region features, but no supervised loss can be computed.
                outputs = self._forward_box_no_loss(features, proposals)
                losses = {}
            return outputs, losses
        else:
            pred_instances = self._forward_box(features, proposals)
            # During inference cascaded prediction is used: the mask and keypoints heads are only
            # applied to the top scoring box detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def _forward_box_no_loss(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """Run the box branch without loss computation (for unlabeled data).

        Returns the same ``outputs`` dict as :meth:`_forward_box` produces in
        training mode: class predictions and pooled box features.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        return {
            'predictions': predictions[0],
            'box_features': box_features
        }

    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """
        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,
        the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with
                their matching ground truth.
                Each has fields "proposal_boxes", and "objectness_logits",
                "gt_classes", "gt_boxes".

        Returns:
            In training, an (outputs dict, dict of losses) pair.
            In inference, a list of `Instances`, the predicted instances.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            outputs = {
                'predictions': predictions[0],
                'box_features': box_features
            }
            return outputs, losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances

    @torch.no_grad()
    def sample_unlabeled_proposals(
            self, proposals: List[Instances]
    ) -> List[Instances]:
        """
        Prepare some unlabeled proposals.
        It returns top ``self.batch_size_per_image`` samples from proposals

        Args:
            proposals (list[Instances]): length `N` list of `Instances`. The i-th
                `Instances` contains object proposals for the i-th input image,
                with fields "proposal_boxes" and "objectness_logits".

        Returns:
            length `N` list of `Instances`s containing the proposals sampled for training.
        """
        return [proposal[:self.batch_size_per_image] for proposal in proposals]
================================================
FILE: tllib/vision/models/reid/__init__.py
================================================
from .resnet import *
__all__ = ['resnet']
================================================
FILE: tllib/vision/models/reid/identifier.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import List, Dict, Optional
import torch
import torch.nn as nn
from torch.nn import init
class ReIdentifier(nn.Module):
    r"""Person re-identifier following the "Bag of Tricks" strong baseline
    (CVPR 2019 workshops).

    The pooled backbone feature :math:`f` goes through a ``BatchNorm1d`` layer
    ("BNNeck") to produce :math:`bn\_f`, which feeds a bias-free ``Linear``
    head for identity classification. In training mode :meth:`forward` returns
    ``(predictions, f)`` — the *pre*-BN feature :math:`f` is the one consumed
    by triplet-style losses — while in eval mode it returns the *post*-BN
    feature :math:`bn\_f`, used for retrieval.

    Args:
        backbone (nn.Module): feature extractor exposing ``out_features``
        num_classes (int): number of identities
        bottleneck (nn.Module, optional): extra projection applied before the BN layer
        bottleneck_dim (int, optional): output dimension of ``bottleneck``;
            must be positive whenever ``bottleneck`` is given
        finetune (bool, optional): if True, :meth:`get_parameters` assigns the
            backbone a smaller relative learning rate
        pool_layer (nn.Module, optional): pooling over the backbone output;
            defaults to global average pooling followed by flatten
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, finetune=True, pool_layer=None):
        super(ReIdentifier, self).__init__()
        if pool_layer is None:
            pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        self.pool_layer = pool_layer
        self.backbone = backbone
        self.num_classes = num_classes
        if bottleneck is None:
            feature_bn = nn.BatchNorm1d(backbone.out_features)
            self.bottleneck = feature_bn
            self._features_dim = backbone.out_features
        else:
            feature_bn = nn.BatchNorm1d(bottleneck_dim)
            self.bottleneck = nn.Sequential(bottleneck, feature_bn)
            self._features_dim = bottleneck_dim
        self.head = nn.Linear(self._features_dim, num_classes, bias=False)
        self.finetune = finetune
        # BNNeck trick: the BN bias stays frozen at zero; classifier gets a
        # small gaussian init.
        feature_bn.bias.requires_grad_(False)
        init.constant_(feature_bn.weight, 1)
        init.constant_(feature_bn.bias, 0)
        init.normal_(self.head.weight, std=0.001)

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor):
        """Return ``(predictions, f)`` in train mode, ``bn_f`` in eval mode."""
        pooled = self.pool_layer(self.backbone(x))
        normed = self.bottleneck(pooled)
        if self.training:
            predictions = self.head(normed)
            return predictions, pooled
        return normed

    def get_parameters(self, base_lr=1.0, rate=0.1) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer
        """
        backbone_lr = rate * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]
================================================
FILE: tllib/vision/models/reid/loss.py
================================================
"""
Modified from https://github.com/yxgeee/MMT
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def pairwise_euclidean_distance(x, y):
    """Pairwise euclidean distances between two sets of 2-d features.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``; the squared
    distances are clamped at ``1e-12`` before the square root for numerical
    stability (so the "self" distance is ~1e-6, not exactly 0).
    """
    sq_x = x.pow(2).sum(dim=1, keepdim=True)      # (m, 1)
    sq_y = y.pow(2).sum(dim=1, keepdim=True).t()  # (1, n)
    dist_sq = sq_x + sq_y - 2 * x.mm(y.t())
    return dist_sq.clamp(min=1e-12).sqrt()
def hard_examples_mining(dist_mat, identity_mat, return_idxes=False):
    r"""Batch-hard mining: for every anchor, select its *furthest* positive and
    its *closest* negative, following `In defense of the Triplet Loss for
    Person Re-Identification (ICCV 2017)`.

    Args:
        dist_mat (tensor): pairwise distance matrix of shape :math:`(N, M)`
        identity_mat (tensor): matrix of shape :math:`(N, M)` where
            :math:`identity\_mat[i, j] = 1` iff sample :math:`i` of the first
            set and sample :math:`j` of the second set share the same person,
            otherwise :math:`0`
        return_idxes (bool, optional): if True, also return the column indexes
            of the mined hard examples. Default: False

    Returns:
        ``(dist_ap, dist_an)``, optionally followed by
        ``(hard_positive_idxes, hard_negative_idxes)``.
    """
    offset = 1e7
    # Trick: push every *different*-identity entry far down by a constant
    # offset, then sort descending — same-person entries rank first, with the
    # largest same-person distance (hardest positive) in column 0.
    pos_sorted, pos_idxes = torch.sort(dist_mat + (-offset) * (1 - identity_mat),
                                       dim=1, descending=True)
    dist_ap = pos_sorted[:, 0]
    hard_positive_idxes = pos_idxes[:, 0]
    # Symmetric trick: push every *same*-identity entry far up and sort
    # ascending — the smallest different-person distance (hardest negative)
    # lands in column 0.
    neg_sorted, neg_idxes = torch.sort(dist_mat + offset * identity_mat,
                                       dim=1, descending=False)
    dist_an = neg_sorted[:, 0]
    hard_negative_idxes = neg_idxes[:, 0]
    if return_idxes:
        return dist_ap, dist_an, hard_positive_idxes, hard_negative_idxes
    return dist_ap, dist_an
class CrossEntropyLossWithLabelSmooth(nn.Module):
    r"""Cross entropy loss with label smoothing from `Rethinking the Inception
    Architecture for Computer Vision (CVPR 2016)`.

    Given one-hot labels :math:`labels \in R^C`, where :math:`C` is the number of classes,
    smoothed labels are calculated as

    .. math::
        smoothed\_labels = (1 - \epsilon) \times labels + \epsilon \times \frac{1}{C}

    We use smoothed labels when calculating cross entropy loss, which can be
    helpful for preventing over-fitting.

    Args:
        num_classes (int): number of classes.
        epsilon (float): a float number that controls the smoothness.

    Inputs:
        - y (tensor): unnormalized classifier predictions, :math:`y`
        - labels (tensor): ground truth class indices (integer dtype), :math:`labels`

    Shape:
        - y: :math:`(minibatch, C)`, where :math:`C` is the number of classes
        - labels: :math:`(minibatch, )`
    """

    def __init__(self, num_classes, epsilon=0.1):
        super(CrossEntropyLossWithLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        # FIX: removed the hard-coded ``.cuda()``. ``LogSoftmax`` holds no
        # parameters or buffers, so the call was a no-op; dropping it keeps the
        # module device-agnostic by construction.
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, y, labels):
        log_prob = self.log_softmax(y)
        # one-hot encode the integer labels, then smooth them
        one_hot = torch.zeros_like(log_prob).scatter_(1, labels.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        # average over the batch, summed over classes
        loss = (- smoothed * log_prob).mean(0).sum()
        return loss
class TripletLoss(nn.Module):
    """Triplet loss with batch-hard mining from `In defense of the Triplet
    Loss for Person Re-Identification (ICCV 2017)`.

    Args:
        margin (float): margin of the underlying ranking loss
        normalize_feature (bool, optional): if True, L2-normalize features
            before computing distances (equivalent to cosine distance).
            Default: False.
    """

    def __init__(self, margin, normalize_feature=False):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda()

    def forward(self, f, labels):
        if self.normalize_feature:
            # equivalent to cosine similarity
            f = F.normalize(f)
        dist_mat = pairwise_euclidean_distance(f, f)
        batch_size = dist_mat.size(0)
        # same_identity[i, j] == 1 iff labels[i] == labels[j]
        same_identity = labels.expand(batch_size, batch_size).eq(
            labels.expand(batch_size, batch_size).t()).float()
        dist_ap, dist_an = hard_examples_mining(dist_mat, same_identity)
        # target == 1 asks the ranking loss to push dist_an above dist_ap
        target = torch.ones_like(dist_ap)
        return self.margin_loss(dist_an, dist_ap, target)
class TripletLossXBM(nn.Module):
    r"""Batch-hard triplet loss (ICCV 2017) extended with a cross-batch memory
    (XBM): hard positives and negatives are mined against features collected
    from an external storage instead of only the current mini batch.

    Args:
        margin (float, optional): margin of the ranking loss. Default: 0.3
        normalize_feature (bool, optional): if True, L2-normalize features
            before computing distances. Default: False

    Inputs:
        - f (tensor): features of the current mini batch, :math:`f`
        - labels (tensor): identity labels for the current mini batch
        - xbm_f (tensor): features fetched from the XBM, :math:`xbm\_f`
        - xbm_labels (tensor): identity labels matching ``xbm_f``

    Shape:
        - f: :math:`(minibatch, F)`, where :math:`F` is the feature dimension
        - labels: :math:`(minibatch, )`
        - xbm_f: :math:`(minibatch, F)`
        - xbm_labels: :math:`(minibatch, )`
    """

    def __init__(self, margin=0.3, normalize_feature=False):
        super(TripletLossXBM, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, f, labels, xbm_f, xbm_labels):
        if self.normalize_feature:
            # equivalent to cosine similarity
            f = F.normalize(f)
            xbm_f = F.normalize(xbm_f)
        dist_mat = pairwise_euclidean_distance(f, xbm_f)
        n, m = f.size(0), xbm_f.size(0)
        # identity_mat[i, j] == 1 iff labels[i] == xbm_labels[j]
        identity_mat = labels.expand(m, n).t().eq(xbm_labels.expand(n, m)).float()
        dist_ap, dist_an = hard_examples_mining(dist_mat, identity_mat)
        # target == 1 asks the ranking loss to push dist_an above dist_ap
        target = torch.ones_like(dist_an)
        return self.ranking_loss(dist_an, dist_ap, target)
class SoftTripletLoss(nn.Module):
    r"""Soft triplet loss from `Mutual Mean-Teaching: Pseudo Label Refinery for
    Unsupervised Domain Adaptation on Person Re-identification (ICLR 2020)`.

    For each anchor, the hardest positive/negative distances
    :math:`(d_{ap}, d_{an})` are mined from ``features_1``. With a fixed
    ``margin`` the loss is

    .. math::
        loss = (- margin \cdot \log p_{ap} - (1 - margin) \cdot \log p_{an}),

    where :math:`(p_{ap}, p_{an}) = \text{softmax}((d_{ap}, d_{an}))`.
    When ``margin`` is None, the softmax over the *teacher* distances
    (``features_2``, taken at the very same mined index pairs) provides soft
    targets for a cross-entropy-style objective.

    Args:
        margin (float, optional): if None, soft labels from ``features_2`` are
            used instead of a fixed margin. Default: None.
        normalize_feature (bool, optional): if True, L2-normalize features
            before computing distances. Default: False.
    """

    def __init__(self, margin=None, normalize_feature=False):
        super(SoftTripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature

    def forward(self, features_1, features_2, labels):
        if self.normalize_feature:
            # equal to cosine similarity
            features_1 = F.normalize(features_1)
            features_2 = F.normalize(features_2)

        dist_mat = pairwise_euclidean_distance(features_1, features_1)
        assert dist_mat.size(0) == dist_mat.size(1)
        n = dist_mat.size(0)
        identity_mat = labels.expand(n, n).eq(labels.expand(n, n).t()).float()
        dist_ap, dist_an, ap_idxes, an_idxes = hard_examples_mining(
            dist_mat, identity_mat, return_idxes=True)
        assert dist_an.size(0) == dist_ap.size(0)

        log_probs = F.log_softmax(torch.stack((dist_ap, dist_an), dim=1), dim=1)
        if self.margin is not None:
            # fixed weighting between the positive and negative log-probabilities
            return (- self.margin * log_probs[:, 0]
                    - (1 - self.margin) * log_probs[:, 1]).mean()

        # teacher distances at exactly the indexes mined from features_1
        dist_mat_ref = pairwise_euclidean_distance(features_2, features_2)
        dist_ap_ref = dist_mat_ref.gather(1, ap_idxes.view(n, 1)).squeeze(1)
        dist_an_ref = dist_mat_ref.gather(1, an_idxes.view(n, 1)).squeeze(1)
        soft_targets = F.softmax(torch.stack((dist_ap_ref, dist_an_ref), dim=1),
                                 dim=1).detach()
        return (- soft_targets * log_probs).sum(dim=1).mean()
class CrossEntropyLoss(nn.Module):
    r"""Soft-label cross entropy between two sets of unnormalized logits.

    With :math:`C` classes and mini-batch size :math:`N`, this criterion
    expects unnormalized predictions :math:`y\_{logits}` of shape
    :math:`(N, C)` and :math:`target\_{logits}` of the same shape. Both are
    first normalized into probability distributions

    .. math::
        y = \text{softmax}(y\_{logits}), \qquad
        target = \text{softmax}(target\_{logits})

    and the objective is

    .. math::
        \text{loss} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^C
        -target_i^j \times \text{log} (y_i^j)

    The target distribution is detached from the autograd graph.
    """

    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        # FIX: removed the hard-coded ``.cuda()``. ``LogSoftmax`` is
        # parameterless, so the call was a no-op and only hurt portability.
        self.log_softmax = nn.LogSoftmax(dim=1)

    def forward(self, y, labels):
        log_prob = self.log_softmax(y)
        target = F.softmax(labels, dim=1).detach()
        return (- target * log_prob).sum(dim=1).mean()
================================================
FILE: tllib/vision/models/reid/resnet.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck
__all__ = ['reid_resnet18', 'reid_resnet34', 'reid_resnet50', 'reid_resnet101']
class ReidResNet(ResNet):
    r"""Modified `ResNet` architecture for ReID from `Mutual Mean-Teaching:
    Pseudo Label Refinery for Unsupervised Domain Adaptation on Person
    Re-identification (ICLR 2020)`.

    Two ReID-specific tweaks over the plain backbone:

    1. the stride of ``layer4[0].conv2`` and of its downsample conv is set to
       1, keeping a larger final feature map;
    2. the stem ReLU after ``bn1`` is deliberately *not* applied during the
       forward pass.
    """

    def __init__(self, *args, **kwargs):
        super(ReidResNet, self).__init__(*args, **kwargs)
        # keep layer4 at stride 1 (higher output resolution)
        self.layer4[0].conv2.stride = (1, 1)
        self.layer4[0].downsample[0].stride = (1, 1)

    def forward(self, x):
        # NOTE: self.relu is intentionally skipped after bn1 (ReID tweak).
        out = self.maxpool(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out
def _reid_resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Build a :class:`ReidResNet`, optionally loading ImageNet weights.

    Downloaded state-dict entries whose names do not exist in this model are
    dropped before loading (``strict=False`` tolerates the remainder).
    """
    model = ReidResNet(block, layers, **kwargs)
    if pretrained:
        own_state = model.state_dict()
        downloaded = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        compatible = {name: value for name, value in downloaded.items()
                      if name in own_state}
        model.load_state_dict(compatible, strict=False)
    return model
def reid_resnet18(pretrained=False, progress=True, **kwargs):
    r"""Construct a Reid-ResNet-18 model (see :class:`ReidResNet`).

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _reid_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
def reid_resnet34(pretrained=False, progress=True, **kwargs):
    r"""Construct a Reid-ResNet-34 model (see :class:`ReidResNet`).

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _reid_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
def reid_resnet50(pretrained=False, progress=True, **kwargs):
    r"""Construct a Reid-ResNet-50 model (see :class:`ReidResNet`).

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _reid_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def reid_resnet101(pretrained=False, progress=True, **kwargs):
    r"""Construct a Reid-ResNet-101 model (see :class:`ReidResNet`).

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _reid_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
================================================
FILE: tllib/vision/models/resnet.py
================================================
"""
Modified based on torchvision.models.resnet.
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torchvision import models
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
import copy
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
class ResNet(models.ResNet):
    """torchvision ResNet used as a headless backbone.

    :meth:`forward` stops after ``layer4`` and returns the raw ``(N, C, H, W)``
    feature map — no pooling, flattening or fully connected layer is applied.
    The original ``fc`` layer is kept only so it can be cloned via
    :meth:`copy_head`.
    """

    def __init__(self, *args, **kwargs):
        super(ResNet, self).__init__(*args, **kwargs)
        # the fc input width is the backbone's output dimensionality
        self._out_features = self.fc.in_features

    def forward(self, x):
        """Return the layer4 feature map (no pooling / classification)."""
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out

    @property
    def out_features(self) -> int:
        """The dimension of output features"""
        return self._out_features

    def copy_head(self) -> nn.Module:
        """Return a deep copy of the original fully connected head."""
        return copy.deepcopy(self.fc)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    """Instantiate a headless :class:`ResNet`, optionally loading ImageNet weights.

    Downloaded state-dict entries whose names do not exist in this model are
    dropped before loading (``strict=False`` tolerates the remainder).
    """
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        own_state = model.state_dict()
        downloaded = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        compatible = {name: value for name, value in downloaded.items()
                      if name in own_state}
        model.load_state_dict(compatible, strict=False)
    return model
def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 from `"Deep Residual Learning for Image Recognition"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 from `"Deep Residual Learning for Image Recognition"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 from `"Deep Residual Learning for Image Recognition"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 from `"Deep Residual Learning for Image Recognition"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 from `"Deep Residual Learning for Image Recognition"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d from `"Aggregated Residual Transformation for Deep
    Neural Networks"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    kwargs.update(groups=32, width_per_group=4)
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d from `"Aggregated Residual Transformation for Deep
    Neural Networks"`.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    kwargs.update(groups=32, width_per_group=8)
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 from `"Wide Residual Networks"`.

    Same topology as ResNet-50 except that every bottleneck's inner channel
    count is doubled (e.g. the last block uses 2048-1024-2048 channels instead
    of 2048-512-2048); the outer 1x1 channel counts are unchanged.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
                   pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 from `"Wide Residual Networks"`.

    Same topology as ResNet-101 except that every bottleneck's inner channel
    count is doubled; the outer 1x1 channel counts are unchanged.

    Args:
        pretrained (bool): if True, load weights pre-trained on ImageNet
        progress (bool): if True, display a download progress bar on stderr
    """
    kwargs.update(width_per_group=64 * 2)
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
                   pretrained, progress, **kwargs)
================================================
FILE: tllib/vision/models/segmentation/__init__.py
================================================
from .deeplabv2 import *
__all__ = ['deeplabv2']
================================================
FILE: tllib/vision/models/segmentation/deeplabv2.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
# download URL of the pre-trained DeepLab-v2 checkpoint hosted by the authors
model_urls = {
    'deeplabv2_resnet101': 'https://cloud.tsinghua.edu.cn/f/2d9a7fc43ce34f76803a/?dl=1'
}
# whether BatchNorm layers carry learnable affine parameters; note the code
# below freezes (requires_grad=False) these parameters everywhere regardless
affine_par = True
class Bottleneck(nn.Module):
    """DeepLab-v2 ResNet bottleneck: 1x1 (carries the stride) -> dilated 3x3
    -> 1x1 (4x channel expansion), with every BatchNorm frozen
    (requires_grad=False), following the original DeepLab recipe.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        # NOTE: unlike torchvision, the stride sits on the first 1x1 conv
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        # padding == dilation keeps the spatial size for the 3x3 conv
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=dilation, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
        # freeze all batchnorm parameters
        for bn in (self.bn1, self.bn2, self.bn3):
            for param in bn.parameters():
                param.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = self.downsample(x) if self.downsample is not None else x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        return self.relu(out)
class ASPP_V2(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLab v2): parallel dilated 3x3
    classifiers whose logits are summed.

    Args:
        inplanes (int): number of input channels
        dilation_series (list[int]): dilation of each parallel branch
        padding_series (list[int]): padding of each parallel branch (paired
            with ``dilation_series``)
        num_classes (int): number of output classes / logit channels
    """

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(ASPP_V2, self).__init__()
        self.conv2d_list = nn.ModuleList([
            nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1,
                      padding=padding, dilation=dilation, bias=True)
            for dilation, padding in zip(dilation_series, padding_series)
        ])
        for conv in self.conv2d_list:
            conv.weight.data.normal_(0, 0.01)

    def forward(self, x):
        branches = iter(self.conv2d_list)
        out = next(branches)(x)
        for conv in branches:
            out = out + conv(x)
        return out
class ResNet(nn.Module):
    """DeepLab-v2 ResNet backbone: layer3/layer4 keep stride 1 and use
    dilation 2/4 instead (output stride 8). The stem BatchNorm is frozen and
    every conv is initialized from N(0, 0.01).
    """

    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        for param in self.bn1.parameters():
            param.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        # ceil_mode matches the original Caffe pooling behaviour
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                module.weight.data.normal_(0, 0.01)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        # a projection shortcut is also forced for the dilated stages
        # (dilation 2 or 4), as in the reference DeepLab implementation
        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            for param in downsample[1].parameters():
                param.requires_grad = False
        stage = [block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out
class Deeplab(nn.Module):
    """DeepLab segmentation model: a feature backbone plus a pixel classifier.

    Args:
        backbone (nn.Module): feature extractor producing a dense feature map.
        classifier (nn.Module): pixel-wise classification head.
        num_classes (int): number of segmentation classes.
    """

    def __init__(self, backbone, classifier, num_classes):
        super(Deeplab, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.num_classes = num_classes

    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

    def get_1x_lr_params_NOscale(self):
        """
        This generator returns all the parameters of the net except for
        the last classification layer. Note that for each batchnorm layer,
        requires_grad is set to False in deeplab_resnet.py, therefore this function does not return
        any batchnorm parameter
        """
        backbone = self.backbone
        stages = (backbone.conv1, backbone.bn1, backbone.layer1,
                  backbone.layer2, backbone.layer3, backbone.layer4)
        for stage in stages:
            for module in stage.modules():
                for parameter in module.parameters():
                    if parameter.requires_grad:
                        yield parameter

    def get_10x_lr_params(self):
        """
        This generator returns all the parameters for the last layer of the net,
        which does the classification of pixel into classes
        """
        yield from self.classifier.parameters()

    def get_parameters(self, lr=1.):
        # The backbone is trained with a 10x smaller learning rate than the head.
        return [
            {'params': self.get_1x_lr_params_NOscale(), 'lr': 0.1 * lr},
            {'params': self.get_10x_lr_params(), 'lr': lr}
        ]
def deeplabv2_resnet101(num_classes=19, pretrained_backbone=True):
    """Constructs a DeepLabV2 model with a ResNet-101 backbone.
    Args:
        num_classes (int, optional): number of classes. Default: 19
        pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. Default: True.
    """
    # ResNet-101 stage layout: 3, 4, 23, 3 bottleneck blocks.
    backbone = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained_backbone:
        # download from Internet
        saved_state_dict = load_state_dict_from_url(model_urls['deeplabv2_resnet101'], map_location=lambda storage, loc: storage, file_name="deeplabv2_resnet101.pth")
        # Remap checkpoint keys onto this backbone: the leading scope
        # (i_parts[0]) is dropped, and entries whose second component is
        # 'layer5' are skipped (presumably the checkpoint's own classifier
        # head — TODO confirm against the actual checkpoint layout).
        new_params = backbone.state_dict().copy()
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        backbone.load_state_dict(new_params)
    # Four parallel ASPP branches with dilations/paddings 6, 12, 18, 24
    # on the 2048-channel layer4 output (DeepLab v2 configuration).
    classifier = ASPP_V2(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)
    return Deeplab(backbone, classifier, num_classes)
================================================
FILE: tllib/vision/transforms/__init__.py
================================================
import math
import random
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import Normalize
class ResizeImage(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            output size will be (size, size)
    """

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        th, tw = self.size
        # PIL's Image.resize expects a (width, height) tuple while self.size
        # is documented as (h, w), so the order must be swapped here.
        # Previously (th, tw) was passed directly, transposing non-square
        # outputs relative to the documented contract.
        return img.resize((tw, th))

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)
class MultipleApply:
    """Apply a list of transformations to one image and collect every result.

    Args:
        transforms (list or tuple): list of transformations

    Example:
        >>> transform1 = T.Compose([
        ...     ResizeImage(256),
        ...     T.RandomCrop(224)
        ... ])
        >>> transform2 = T.Compose([
        ...     ResizeImage(256),
        ...     T.RandomCrop(224),
        ... ])
        >>> multiply_transform = MultipleApply([transform1, transform2])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        outputs = []
        for transform in self.transforms:
            outputs.append(transform(image))
        return outputs

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        for transform in self.transforms:
            lines.append('    {0}'.format(transform))
        return '\n'.join(lines) + '\n)'
class Denormalize(Normalize):
    """Invert a ``Normalize`` transform.

    Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for
    ``n`` channels, this transform maps a normalized tensor back to its
    original range, i.e.
    ``output[channel] = input[channel] * std[channel] + mean[channel]``.

    .. note::
        This transform acts out of place, i.e., it does not mutate the input tensor.

    Args:
        mean (sequence): Sequence of means for each channel.
        std (sequence): Sequence of standard deviations for each channel.
    """

    def __init__(self, mean, std):
        # x_norm = (x - m) / s  =>  x = x_norm * s + m, which is exactly a
        # Normalize with mean' = -m/s and std' = 1/s.
        mean = np.asarray(mean)
        std = np.asarray(std)
        super().__init__((-mean / std).tolist(), (1 / std).tolist())
class NormalizeAndTranspose:
    """
    First, normalize a tensor image with mean and standard deviation.
    Then, convert the shape (H x W x C) to shape (C x H x W).
    """

    def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):
        # Per-channel mean, subtracted after the RGB -> BGR swap
        # (Caffe-style statistics — presumably ImageNet; verify upstream).
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image):
        if isinstance(image, Image.Image):
            array = np.asarray(image, np.float32)
            # RGB -> BGR, subtract the mean, then HWC -> CHW.
            array = array[:, :, ::-1]
            array -= self.mean
            return array.transpose((2, 0, 1)).copy()
        if isinstance(image, torch.Tensor):
            # RGB -> BGR via channel indexing, then mean-subtract and permute.
            tensor = image[:, :, [2, 1, 0]]
            tensor -= torch.from_numpy(self.mean).to(tensor.device)
            return tensor.permute((2, 0, 1))
        raise NotImplementedError(type(image))
class DeNormalizeAndTranspose:
    """
    First, convert a tensor image from the shape (C x H x W ) to shape (H x W x C).
    Then, denormalize it with mean and standard deviation.
    """

    def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)):
        # Per-channel (BGR) mean that was subtracted during normalization.
        self.mean = np.array(mean, dtype=np.float32)

    def __call__(self, image):
        # CHW -> HWC, add the mean back, then flip BGR -> RGB.
        hwc = image.transpose((1, 2, 0))
        hwc += self.mean
        return hwc[:, :, ::-1]
class RandomErasing(object):
    """Random erasing augmentation from `Random Erasing Data Augmentation (CVPR 2017)
    `_. Randomly selects a rectangle in a (C x H x W) tensor image and fills it
    with the per-channel mean value.

    Args:
        probability (float): The probability that the Random Erasing operation will be performed.
        sl (float): Minimum proportion of erased area against input image.
        sh (float): Maximum proportion of erased area against input image.
        r1 (float): Minimum aspect ratio of erased area.
        mean (sequence): Value to fill the erased area.
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        self.mean = mean
        self.sl = sl
        self.sh = sh
        self.r1 = r1

    def __call__(self, img):
        # Skip erasing with probability (1 - self.probability).
        if random.uniform(0, 1) >= self.probability:
            return img
        shape = img.size()
        channels, height, width = shape[0], shape[1], shape[2]
        area = height * width
        # Try up to 100 times to sample a rectangle that fits inside the image.
        for _ in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)
            rect_h = int(round(math.sqrt(target_area * aspect_ratio)))
            rect_w = int(round(math.sqrt(target_area / aspect_ratio)))
            if rect_w < width and rect_h < height:
                top = random.randint(0, height - rect_h)
                left = random.randint(0, width - rect_w)
                # Fill each channel with its mean (single channel uses mean[0]).
                fill_channels = 3 if channels == 3 else 1
                for c in range(fill_channels):
                    img[c, top:top + rect_h, left:left + rect_w] = self.mean[c]
                return img
        return img

    def __repr__(self):
        return self.__class__.__name__ + '(p={})'.format(self.probability)
================================================
FILE: tllib/vision/transforms/keypoint_detection.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
# TODO needs better documentation
import numpy as np
from PIL import ImageFilter, Image
import torchvision.transforms.functional as F
import torchvision.transforms.transforms as T
import numbers
import random
import math
import warnings
from typing import ClassVar
def wrapper(transform: ClassVar):
    """ Wrap a transform for classification to a transform for keypoint detection.
    Note that the keypoint detection label will keep the same before and after wrapper.
    Args:
        transform (class, callable): transform for classification
    Returns:
        transform for keypoint detection
    """
    class WrapperTransform(transform):
        # Apply the wrapped transform to the image only; every keypoint
        # annotation passed via keyword arguments is returned untouched.
        def __call__(self, image, **kwargs):
            return super().__call__(image), kwargs

    return WrapperTransform
# Standard classification transforms lifted to the keypoint-detection
# interface: each takes (image, **kwargs) and returns (image, kwargs).
ToTensor = wrapper(T.ToTensor)
Normalize = wrapper(T.Normalize)
ColorJitter = wrapper(T.ColorJitter)
def resize(image: Image.Image, size: int, interpolation=Image.BILINEAR,
           keypoint2d: np.ndarray = None, intrinsic_matrix: np.ndarray = None):
    """Resize a square image to ``size`` x ``size`` and rescale the 2D
    keypoints and camera intrinsics by the same factor.
    """
    width, height = image.size
    assert width == height
    scale = float(size) / float(width)
    resized = F.resize(image, size, interpolation)
    scaled_keypoint2d = np.copy(keypoint2d)
    scaled_keypoint2d *= scale
    # Focal lengths and principal point scale with the image resolution.
    scaled_intrinsics = np.copy(intrinsic_matrix)
    scaled_intrinsics[0][0] *= scale
    scaled_intrinsics[0][2] *= scale
    scaled_intrinsics[1][1] *= scale
    scaled_intrinsics[1][2] *= scale
    return resized, scaled_keypoint2d, scaled_intrinsics
def crop(image: Image.Image, top, left, height, width, keypoint2d: np.ndarray):
    """Crop the image and shift the 2D keypoints into the crop's coordinate frame."""
    cropped = F.crop(image, top, left, height, width)
    shifted = np.copy(keypoint2d)
    shifted[:, 0] -= left
    shifted[:, 1] -= top
    return cropped, shifted
def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR,
                 keypoint2d: np.ndarray = None, intrinsic_matrix: np.ndarray = None):
    """Crop the given PIL Image, resize it to the desired size, and update
    the 2D keypoints and camera intrinsics to match.

    Notably used in :class:`~torchvision.transforms.RandomResizedCrop`.

    Args:
        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
        top (int): Vertical component of the top left corner of the crop box.
        left (int): Horizontal component of the top left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.
        size (sequence or int): Desired output size. Same semantics as ``resize``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    Returns:
        PIL Image: Cropped image.
    """
    assert isinstance(img, Image.Image), 'img should be PIL Image'
    cropped, keypoint2d = crop(img, top, left, height, width, keypoint2d)
    return resize(cropped, size, interpolation, keypoint2d, intrinsic_matrix)
def center_crop(image, output_size, keypoint2d: np.ndarray):
    """Crop the central ``output_size`` region of the image.

    Args:
        image (PIL Image): Image to be cropped. (0,0) denotes the top left corner.
        output_size (sequence): (height, width) of the crop box.
        keypoint2d (np.ndarray): 2D keypoints, shifted into the crop's frame.
    Returns:
        Cropped image and the shifted keypoints.
    """
    width, height = image.size
    crop_height, crop_width = output_size
    top = int(round((height - crop_height) / 2.))
    left = int(round((width - crop_width) / 2.))
    return crop(image, top, left, crop_height, crop_width, keypoint2d)
def hflip(image: Image.Image, keypoint2d: np.ndarray):
    """Mirror the image horizontally and reflect the keypoint x coordinates."""
    width, _ = image.size
    flipped = F.hflip(image)
    mirrored = np.copy(keypoint2d)
    # Pixel centers run from 0 to width-1, hence the (width - 1) reflection.
    mirrored[:, 0] = width - 1. - mirrored[:, 0]
    return flipped, mirrored
def rotate(image: Image.Image, angle, keypoint2d: np.ndarray):
    """Rotate the image by ``angle`` degrees and apply the matching 2D rotation
    (about the image center) to the keypoints.
    """
    rotated = F.rotate(image, angle)
    # Image coordinates have y pointing down, hence the sign flip.
    theta = -np.deg2rad(angle)
    rotation = np.array([
        [np.cos(theta), -np.sin(theta)],
        [np.sin(theta), np.cos(theta)]
    ])
    width, height = image.size
    # Translate to the center, rotate, translate back.
    points = np.copy(keypoint2d)
    points[:, 0] = points[:, 0] - width / 2
    points[:, 1] = points[:, 1] - height / 2
    points = np.matmul(rotation, points.T).T
    points[:, 0] = points[:, 0] + width / 2
    points[:, 1] = points[:, 1] + height / 2
    return rotated, points
def resize_pad(img, keypoint2d, size, interpolation=Image.BILINEAR):
    """Resize the longer side of ``img`` to ``size`` while keeping the aspect
    ratio, then zero-pad the shorter side so the result is ``size`` x ``size``.
    Keypoints are scaled and shifted to match.
    """
    w, h = img.size
    if w < h:
        # Portrait: height becomes `size`, width is scaled proportionally.
        target_h, target_w = size, int(size * w / h)
        keypoint2d = keypoint2d * target_h / h
    else:
        # Landscape (or square): width becomes `size`.
        target_w, target_h = size, int(size * h / w)
        keypoint2d = keypoint2d * target_w / w
    img = img.resize((target_w, target_h), interpolation)
    # Split any leftover space evenly; ceil goes to the right/bottom edge.
    pad_left = math.floor((size - target_w) / 2)
    pad_right = math.ceil((size - target_w) / 2)
    pad_top = math.floor((size - target_h) / 2)
    pad_bottom = math.ceil((size - target_h) / 2)
    keypoint2d[:, 0] += (size - target_w) / 2
    keypoint2d[:, 1] += (size - target_h) / 2
    padded = np.pad(np.asarray(img),
                    ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
                    'constant', constant_values=0)
    return Image.fromarray(padded), keypoint2d
class Compose(object):
    """Composes several keypoint transforms together.

    Each transform takes ``(image, **kwargs)`` and returns the transformed
    image together with the (possibly updated) keyword annotations.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, **kwargs):
        for transform in self.transforms:
            image, kwargs = transform(image, **kwargs)
        return image, kwargs
class GaussianBlur(object):
    """Blur the image with a Gaussian kernel whose radius is drawn uniformly
    from ``[low, high)``. Keypoint annotations are passed through unchanged.
    """

    def __init__(self, low=0, high=0.8):
        self.low = low
        self.high = high

    def __call__(self, image: Image, **kwargs):
        radius = np.random.uniform(low=self.low, high=self.high)
        blurred = image.filter(ImageFilter.GaussianBlur(radius))
        return blurred, kwargs
class Resize(object):
    """Resize the input PIL Image (and its keypoint annotations) to the given size.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        assert isinstance(size, int)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs):
        resized = resize(image, self.size, self.interpolation, keypoint2d, intrinsic_matrix)
        image, kwargs['keypoint2d'], kwargs['intrinsic_matrix'] = resized
        # The depth map, when present, is resized alongside the image.
        if 'depth' in kwargs:
            kwargs['depth'] = F.resize(kwargs['depth'], self.size)
        return image, kwargs
class ResizePad(object):
    """Resize the longer edge to ``size`` and zero-pad the shorter edge so the
    output is a square ``size`` x ``size`` image; keypoints are updated to match.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img, keypoint2d, **kwargs):
        image, kwargs['keypoint2d'] = resize_pad(img, keypoint2d, self.size, self.interpolation)
        return image, kwargs
class CenterCrop(object):
    """Crops the given PIL Image (and its keypoint annotations) at the center.
    """

    def __init__(self, size):
        # An int size means a square crop.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, image, keypoint2d, **kwargs):
        """
        Args:
            img (PIL Image): Image to be cropped.
        Returns:
            PIL Image: Cropped image.
        """
        image, kwargs['keypoint2d'] = center_crop(image, self.size, keypoint2d)
        # Crop the depth map identically when one is provided.
        if 'depth' in kwargs:
            kwargs['depth'] = F.center_crop(kwargs['depth'], self.size)
        return image, kwargs
class RandomRotation(object):
    """Rotate the image (and keypoints) by a randomly sampled angle.

    Args:
        degrees (sequence or float or int): Range of degrees to select from.
            If degrees is a number instead of sequence like (min, max), the range
            of degrees will be (-degrees, +degrees).
    """

    def __init__(self, degrees):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

    @staticmethod
    def get_params(degrees):
        """Sample a rotation angle uniformly from the ``degrees`` range."""
        return random.uniform(degrees[0], degrees[1])

    def __call__(self, image, keypoint2d, **kwargs):
        """Rotate ``image``, its keypoints, and any provided depth map by the
        same random angle.
        """
        angle = self.get_params(self.degrees)
        image, kwargs['keypoint2d'] = rotate(image, angle, keypoint2d)
        if 'depth' in kwargs:
            kwargs['depth'] = F.rotate(kwargs['depth'], angle)
        return image, kwargs
class RandomResizedCrop(object):
    """Crop the given PIL Image to a random size (fixed unit aspect ratio) and
    resize the crop to ``size``.

    Args:
        size: expected output size of each edge
        scale: range of size of the origin size cropped
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.6, 1.3), interpolation=Image.BILINEAR):
        self.size = size
        if scale[0] > scale[1]:
            warnings.warn("range should be of kind (min, max)")
        self.interpolation = interpolation
        self.scale = scale

    @staticmethod
    def get_params(img, scale):
        """Sample ``(top, left, h, w)`` crop parameters.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop``.
        """
        width, height = img.size
        area = height * width
        for _ in range(10):
            # The aspect ratio is fixed at 1, so the crop side lengths match.
            side = int(round(math.sqrt(random.uniform(*scale) * area)))
            if 0 < side <= width and 0 < side <= height:
                top = random.randint(0, height - side)
                left = random.randint(0, width - side)
                return top, left, side, side
        # Fallback to whole image
        return 0, 0, height, width

    def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs):
        """Randomly crop and resize ``image`` together with its keypoints,
        intrinsics, and (optional) depth map.
        """
        top, left, h, w = self.get_params(image, self.scale)
        image, kwargs['keypoint2d'], kwargs['intrinsic_matrix'] = resized_crop(
            image, top, left, h, w, self.size, self.interpolation, keypoint2d, intrinsic_matrix)
        if 'depth' in kwargs:
            kwargs['depth'] = F.resized_crop(kwargs['depth'], top, left, h, w, self.size, self.interpolation,)
        return image, kwargs
class RandomApply(T.RandomTransforms):
    """Apply randomly a list of keypoint transformations with a given probability.

    Args:
        transforms (list or tuple or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, image, **kwargs):
        # With probability (1 - p) the whole list is skipped.
        if self.p < random.random():
            return image, kwargs
        for transform in self.transforms:
            image, kwargs = transform(image, **kwargs)
        return image, kwargs
================================================
FILE: tllib/vision/transforms/segmentation.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from PIL import Image
import random
import math
from typing import ClassVar, Sequence, List, Tuple
from torch import Tensor
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.transforms as T
import torch.nn as nn
from . import MultipleApply as MultipleApplyBase, NormalizeAndTranspose as NormalizeAndTransposeBase
def wrapper(transform: ClassVar):
    """ Wrap a transform for classification to a transform for segmentation.
    Note that the segmentation label will keep the same before and after wrapper.
    Args:
        transform (class, callable): transform for classification
    Returns:
        transform for segmentation
    """
    class WrapperTransform(transform):
        # Apply the wrapped transform to the image only; the segmentation
        # label passes through unchanged.
        def __call__(self, image, label):
            return super().__call__(image), label

    return WrapperTransform
# Classification transforms wrapped to the segmentation interface: each takes
# (image, label) and returns (image, label) with the label untouched.
ColorJitter = wrapper(T.ColorJitter)
Normalize = wrapper(T.Normalize)
ToTensor = wrapper(T.ToTensor)
ToPILImage = wrapper(T.ToPILImage)
MultipleApply = wrapper(MultipleApplyBase)
NormalizeAndTranspose = wrapper(NormalizeAndTransposeBase)
class Compose:
    """Composes several segmentation transforms together.

    Args:
        transforms (list): list of transforms to compose.
    Example:
        >>> Compose([
        >>>     Resize((512, 512)),
        >>>     RandomHorizontalFlip()
        >>> ])
    """

    def __init__(self, transforms):
        super(Compose, self).__init__()
        self.transforms = transforms

    def __call__(self, image, target):
        for transform in self.transforms:
            image, target = transform(image, target)
        return image, target
class Resize(nn.Module):
    """Resize the input image and the corresponding label to the given size.
    The image should be a PIL Image.

    Args:
        image_size (sequence): The requested image size in pixels, as a 2-tuple:
            (width, height).
        label_size (sequence, optional): The requested segmentation label size in pixels, as a 2-tuple:
            (width, height). The same as image_size if None. Default: None.
    """

    def __init__(self, image_size, label_size=None):
        super(Resize, self).__init__()
        self.image_size = image_size
        # The label may be stored at a different resolution than the image.
        self.label_size = image_size if label_size is None else label_size

    def forward(self, image, label):
        """
        Args:
            image: (PIL Image): Image to be scaled.
            label: (PIL Image): Segmentation label to be scaled.
        Returns:
            Rescaled image, rescaled segmentation label
        """
        # Bicubic for the image; nearest for the label so class ids stay exact.
        return (image.resize(self.image_size, Image.BICUBIC),
                label.resize(self.label_size, Image.NEAREST))
class RandomCrop(nn.Module):
    """Crop the given image (and its segmentation label) at a random location.
    The image can be a PIL Image.

    Args:
        size (sequence): Desired output size of the crop, as (width, height).
    """

    def __init__(self, size):
        super(RandomCrop, self).__init__()
        self.size = size

    def forward(self, image, label):
        """
        Args:
            image: (PIL Image): Image to be cropped.
            label: (PIL Image): Segmentation label to be cropped.
        Returns:
            Cropped image, cropped segmentation label.
        """
        # Valid top-left offsets run from 0 to (image_size - crop_size),
        # inclusive. The previous code used randint(0, slack - 1), which both
        # excluded the right/bottom-most crop position and raised ValueError
        # when the image size exactly equals the crop size (slack == 0).
        max_left = max(image.size[0] - self.size[0], 0)
        max_upper = max(image.size[1] - self.size[1], 0)
        left = random.randint(0, max_left)
        upper = random.randint(0, max_upper)
        box = (left, upper, left + self.size[0], upper + self.size[1])
        # Crop image and label with the identical box so they stay aligned.
        return image.crop(box), label.crop(box)
class RandomHorizontalFlip(nn.Module):
    """Horizontally flip the given PIL Image randomly with a given probability.

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super(RandomHorizontalFlip, self).__init__()
        self.p = p

    def forward(self, image, label):
        """Flip image and label together with probability ``self.p``.

        Args:
            image: (PIL Image): Image to be flipped.
            label: (PIL Image): Segmentation label to be flipped.
        Returns:
            Randomly flipped image, randomly flipped segmentation label.
        """
        flip = random.random() < self.p
        if flip:
            # Image and label are flipped together so masks stay aligned.
            return F.hflip(image), F.hflip(label)
        return image, label
class RandomResizedCrop(T.RandomResizedCrop):
    """Crop the given image to random size and aspect ratio.
    The image can be a PIL Image.
    A crop of random size (default: of 0.5 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    Args:
        size (int or sequence): expected output size of each edge. If size is an
            int instead of sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of size of the origin size cropped
        ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped.
        interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, scale=(0.5, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BICUBIC):
        super(RandomResizedCrop, self).__init__(size, scale, ratio, interpolation)

    @staticmethod
    def get_params(
            img: Tensor, scale: List[float], ratio: List[float]
    ) -> Tuple[int, int, int, int]:
        """Get parameters for ``crop`` for a random sized crop.
        Args:
            img (PIL Image): Input image.
            scale (list): range of scale of the origin size cropped
            ratio (list): range of aspect ratio of the origin aspect ratio cropped
        Returns:
            params (i, j, h, w) to be passed to ``crop`` for a random sized crop.
        """
        # NOTE(review): F._get_image_size is a private torchvision helper and
        # was renamed in later torchvision versions — verify compatibility.
        width, height = F._get_image_size(img)
        area = height * width
        for _ in range(10):
            # Sample a target area and a log-uniform aspect ratio, rejecting
            # crops that do not fit inside the image.
            target_area = area * random.uniform(scale[0], scale[1])
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1]))
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = random.randint(0, height - h)
                j = random.randint(0, width - w)
                return i, j, h, w
        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, image, label):
        """
        Args:
            image: (PIL Image): Image to be cropped and resized.
            label: (PIL Image): Segmentation label to be cropped and resized.
        Returns:
            Randomly cropped and resized image, randomly cropped and resized segmentation label.
        """
        top, left, height, width = self.get_params(image, self.scale, self.ratio)
        image = image.crop((left, top, left + width, top + height))
        image = image.resize(self.size, self.interpolation)
        # The label uses nearest-neighbour resampling so class ids stay exact.
        label = label.crop((left, top, left + width, top + height))
        label = label.resize(self.size, Image.NEAREST)
        return image, label
class RandomChoice(T.RandomTransforms):
    """Apply a single transformation, picked uniformly at random from the list."""

    def __call__(self, image, label):
        transform = random.choice(self.transforms)
        return transform(image, label)
class RandomApply(T.RandomTransforms):
    """Apply randomly a list of transformations with a given probability.

    Args:
        transforms (list or tuple or torch.nn.Module): list of transformations
        p (float): probability
    """

    def __init__(self, transforms, p=0.5):
        super(RandomApply, self).__init__(transforms)
        self.p = p

    def __call__(self, image, label):
        """Apply every transform with probability ``self.p``; otherwise return
        the inputs unchanged. Always returns the ``(image, label)`` pair.
        """
        if self.p < random.random():
            # Fix: previously only `image` was returned here (and below),
            # dropping the label and breaking Compose-style callers that
            # unpack a 2-tuple, as every sibling transform in this module does.
            return image, label
        for t in self.transforms:
            image, label = t(image, label)
        return image, label