Repository: thuml/Transfer-Learning-Library Branch: master Commit: c4aa59eb5656 Files: 445 Total size: 2.6 MB Directory structure: gitextract_h0zbpfvz/ ├── .github/ │ └── ISSUE_TEMPLATE/ │ ├── bug_report.md │ ├── custom.md │ └── feature_request.md ├── .gitignore ├── CONTRIBUTING.md ├── DATASETS.md ├── LICENSE ├── README.md ├── docs/ │ ├── Makefile │ ├── conf.py │ ├── index.rst │ ├── make.bat │ ├── requirements.txt │ └── tllib/ │ ├── alignment/ │ │ ├── domain_adversarial.rst │ │ ├── hypothesis_adversarial.rst │ │ ├── index.rst │ │ └── statistics_matching.rst │ ├── modules.rst │ ├── normalization.rst │ ├── ranking.rst │ ├── regularization.rst │ ├── reweight.rst │ ├── self_training.rst │ ├── translation.rst │ ├── utils/ │ │ ├── analysis.rst │ │ ├── base.rst │ │ ├── index.rst │ │ └── metric.rst │ └── vision/ │ ├── datasets.rst │ ├── index.rst │ ├── models.rst │ └── transforms.rst ├── examples/ │ ├── domain_adaptation/ │ │ ├── image_classification/ │ │ │ ├── README.md │ │ │ ├── adda.py │ │ │ ├── adda.sh │ │ │ ├── afn.py │ │ │ ├── afn.sh │ │ │ ├── bsp.py │ │ │ ├── bsp.sh │ │ │ ├── cc_loss.py │ │ │ ├── cc_loss.sh │ │ │ ├── cdan.py │ │ │ ├── cdan.sh │ │ │ ├── dan.py │ │ │ ├── dan.sh │ │ │ ├── dann.py │ │ │ ├── dann.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── fixmatch.py │ │ │ ├── fixmatch.sh │ │ │ ├── jan.py │ │ │ ├── jan.sh │ │ │ ├── mcc.py │ │ │ ├── mcc.sh │ │ │ ├── mcd.py │ │ │ ├── mcd.sh │ │ │ ├── mdd.py │ │ │ ├── mdd.sh │ │ │ ├── requirements.txt │ │ │ ├── self_ensemble.py │ │ │ ├── self_ensemble.sh │ │ │ └── utils.py │ │ ├── image_regression/ │ │ │ ├── README.md │ │ │ ├── dann.py │ │ │ ├── dann.sh │ │ │ ├── dd.py │ │ │ ├── dd.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── rsd.py │ │ │ ├── rsd.sh │ │ │ └── utils.py │ │ ├── keypoint_detection/ │ │ │ ├── README.md │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── regda.py │ │ │ ├── regda.sh │ │ │ ├── regda_fast.py │ │ │ └── regda_fast.sh │ │ ├── object_detection/ │ │ │ ├── README.md │ │ │ ├── config/ │ │ │ │ ├── 
faster_rcnn_R_101_C4_cityscapes.yaml │ │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml │ │ │ │ ├── faster_rcnn_vgg_16_cityscapes.yaml │ │ │ │ └── retinanet_R_101_FPN_voc.yaml │ │ │ ├── cycle_gan.py │ │ │ ├── cycle_gan.sh │ │ │ ├── d_adapt/ │ │ │ │ ├── README.md │ │ │ │ ├── bbox_adaptation.py │ │ │ │ ├── category_adaptation.py │ │ │ │ ├── config/ │ │ │ │ │ ├── faster_rcnn_R_101_C4_cityscapes.yaml │ │ │ │ │ ├── faster_rcnn_R_101_C4_voc.yaml │ │ │ │ │ ├── faster_rcnn_vgg_16_cityscapes.yaml │ │ │ │ │ └── retinanet_R_101_FPN_voc.yaml │ │ │ │ ├── d_adapt.py │ │ │ │ └── d_adapt.sh │ │ │ ├── oracle.sh │ │ │ ├── prepare_cityscapes_to_voc.py │ │ │ ├── requirements.txt │ │ │ ├── source_only.py │ │ │ ├── source_only.sh │ │ │ ├── utils.py │ │ │ ├── visualize.py │ │ │ └── visualize.sh │ │ ├── openset_domain_adaptation/ │ │ │ ├── README.md │ │ │ ├── dann.py │ │ │ ├── dann.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── osbp.py │ │ │ ├── osbp.sh │ │ │ └── utils.py │ │ ├── partial_domain_adaptation/ │ │ │ ├── README.md │ │ │ ├── afn.py │ │ │ ├── afn.sh │ │ │ ├── dann.py │ │ │ ├── dann.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── iwan.py │ │ │ ├── iwan.sh │ │ │ ├── pada.py │ │ │ ├── pada.sh │ │ │ ├── requirements.txt │ │ │ └── utils.py │ │ ├── re_identification/ │ │ │ ├── README.md │ │ │ ├── baseline.py │ │ │ ├── baseline.sh │ │ │ ├── baseline_cluster.py │ │ │ ├── baseline_cluster.sh │ │ │ ├── ibn.sh │ │ │ ├── mmt.py │ │ │ ├── mmt.sh │ │ │ ├── requirements.txt │ │ │ ├── spgan.py │ │ │ ├── spgan.sh │ │ │ └── utils.py │ │ ├── semantic_segmentation/ │ │ │ ├── README.md │ │ │ ├── advent.py │ │ │ ├── advent.sh │ │ │ ├── cycada.py │ │ │ ├── cycada.sh │ │ │ ├── cycle_gan.py │ │ │ ├── cycle_gan.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── fda.py │ │ │ └── fda.sh │ │ ├── wilds_image_classification/ │ │ │ ├── README.md │ │ │ ├── cdan.py │ │ │ ├── cdan.sh │ │ │ ├── dan.py │ │ │ ├── dan.sh │ │ │ ├── dann.py │ │ │ ├── dann.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── fixmatch.py │ │ │ ├── 
fixmatch.sh │ │ │ ├── jan.py │ │ │ ├── jan.sh │ │ │ ├── mdd.py │ │ │ ├── mdd.sh │ │ │ ├── requirements.txt │ │ │ └── utils.py │ │ ├── wilds_ogb_molpcba/ │ │ │ ├── README.md │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── gin.py │ │ │ ├── requirements.txt │ │ │ └── utils.py │ │ ├── wilds_poverty/ │ │ │ ├── README.md │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── requirements.txt │ │ │ ├── resnet_ms.py │ │ │ └── utils.py │ │ └── wilds_text/ │ │ ├── README.md │ │ ├── erm.py │ │ ├── erm.sh │ │ ├── requirements.txt │ │ └── utils.py │ ├── domain_generalization/ │ │ ├── image_classification/ │ │ │ ├── README.md │ │ │ ├── coral.py │ │ │ ├── coral.sh │ │ │ ├── erm.py │ │ │ ├── erm.sh │ │ │ ├── groupdro.py │ │ │ ├── groupdro.sh │ │ │ ├── ibn.sh │ │ │ ├── irm.py │ │ │ ├── irm.sh │ │ │ ├── mixstyle.py │ │ │ ├── mixstyle.sh │ │ │ ├── mldg.py │ │ │ ├── mldg.sh │ │ │ ├── requirements.txt │ │ │ ├── utils.py │ │ │ ├── vrex.py │ │ │ └── vrex.sh │ │ └── re_identification/ │ │ ├── README.md │ │ ├── baseline.py │ │ ├── baseline.sh │ │ ├── ibn.sh │ │ ├── mixstyle.py │ │ ├── mixstyle.sh │ │ ├── requirements.txt │ │ └── utils.py │ ├── model_selection/ │ │ ├── README.md │ │ ├── hscore.py │ │ ├── hscore.sh │ │ ├── leep.py │ │ ├── leep.sh │ │ ├── logme.py │ │ ├── logme.sh │ │ ├── nce.py │ │ ├── nce.sh │ │ ├── requirements.txt │ │ └── utils.py │ ├── semi_supervised_learning/ │ │ └── image_classification/ │ │ ├── README.md │ │ ├── convert_moco_to_pretrained.py │ │ ├── debiasmatch.py │ │ ├── debiasmatch.sh │ │ ├── dst.py │ │ ├── dst.sh │ │ ├── erm.py │ │ ├── erm.sh │ │ ├── fixmatch.py │ │ ├── fixmatch.sh │ │ ├── flexmatch.py │ │ ├── flexmatch.sh │ │ ├── mean_teacher.py │ │ ├── mean_teacher.sh │ │ ├── noisy_student.py │ │ ├── noisy_student.sh │ │ ├── pi_model.py │ │ ├── pi_model.sh │ │ ├── pseudo_label.py │ │ ├── pseudo_label.sh │ │ ├── requirements.txt │ │ ├── self_tuning.py │ │ ├── self_tuning.sh │ │ ├── uda.py │ │ ├── uda.sh │ │ └── utils.py │ └── task_adaptation/ │ └── image_classification/ │ ├── 
README.md │ ├── bi_tuning.py │ ├── bi_tuning.sh │ ├── bss.py │ ├── bss.sh │ ├── co_tuning.py │ ├── co_tuning.sh │ ├── convert_moco_to_pretrained.py │ ├── delta.py │ ├── delta.sh │ ├── erm.py │ ├── erm.sh │ ├── lwf.py │ ├── lwf.sh │ ├── requirements.txt │ ├── stochnorm.py │ ├── stochnorm.sh │ └── utils.py ├── requirements.txt ├── setup.py └── tllib/ ├── __init__.py ├── alignment/ │ ├── __init__.py │ ├── adda.py │ ├── advent.py │ ├── bsp.py │ ├── cdan.py │ ├── coral.py │ ├── d_adapt/ │ │ ├── __init__.py │ │ ├── feedback.py │ │ ├── modeling/ │ │ │ ├── __init__.py │ │ │ ├── matcher.py │ │ │ ├── meta_arch/ │ │ │ │ ├── __init__.py │ │ │ │ ├── rcnn.py │ │ │ │ └── retinanet.py │ │ │ └── roi_heads/ │ │ │ ├── __init__.py │ │ │ ├── fast_rcnn.py │ │ │ └── roi_heads.py │ │ └── proposal.py │ ├── dan.py │ ├── dann.py │ ├── jan.py │ ├── mcd.py │ ├── mdd.py │ ├── osbp.py │ ├── regda.py │ └── rsd.py ├── modules/ │ ├── __init__.py │ ├── classifier.py │ ├── domain_discriminator.py │ ├── entropy.py │ ├── gl.py │ ├── grl.py │ ├── kernels.py │ ├── loss.py │ └── regressor.py ├── normalization/ │ ├── __init__.py │ ├── afn.py │ ├── ibn.py │ ├── mixstyle/ │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── sampler.py │ └── stochnorm.py ├── ranking/ │ ├── __init__.py │ ├── hscore.py │ ├── leep.py │ ├── logme.py │ ├── nce.py │ └── transrate.py ├── regularization/ │ ├── __init__.py │ ├── bi_tuning.py │ ├── bss.py │ ├── co_tuning.py │ ├── delta.py │ ├── knowledge_distillation.py │ └── lwf.py ├── reweight/ │ ├── __init__.py │ ├── groupdro.py │ ├── iwan.py │ └── pada.py ├── self_training/ │ ├── __init__.py │ ├── cc_loss.py │ ├── dst.py │ ├── flexmatch.py │ ├── mcc.py │ ├── mean_teacher.py │ ├── pi_model.py │ ├── pseudo_label.py │ ├── self_ensemble.py │ ├── self_tuning.py │ └── uda.py ├── translation/ │ ├── __init__.py │ ├── cycada.py │ ├── cyclegan/ │ │ ├── __init__.py │ │ ├── discriminator.py │ │ ├── generator.py │ │ ├── loss.py │ │ ├── transform.py │ │ └── util.py │ ├── fourier_transform.py │ └── 
spgan/ │ ├── __init__.py │ ├── loss.py │ └── siamese.py ├── utils/ │ ├── __init__.py │ ├── analysis/ │ │ ├── __init__.py │ │ ├── a_distance.py │ │ └── tsne.py │ ├── data.py │ ├── logger.py │ ├── meter.py │ ├── metric/ │ │ ├── __init__.py │ │ ├── keypoint_detection.py │ │ └── reid.py │ └── scheduler.py └── vision/ ├── __init__.py ├── datasets/ │ ├── __init__.py │ ├── _util.py │ ├── aircrafts.py │ ├── caltech101.py │ ├── cifar.py │ ├── coco70.py │ ├── cub200.py │ ├── digits.py │ ├── domainnet.py │ ├── dtd.py │ ├── eurosat.py │ ├── food101.py │ ├── imagelist.py │ ├── imagenet_r.py │ ├── imagenet_sketch.py │ ├── keypoint_detection/ │ │ ├── __init__.py │ │ ├── freihand.py │ │ ├── hand_3d_studio.py │ │ ├── human36m.py │ │ ├── keypoint_dataset.py │ │ ├── lsp.py │ │ ├── rendered_hand_pose.py │ │ ├── surreal.py │ │ └── util.py │ ├── object_detection/ │ │ └── __init__.py │ ├── office31.py │ ├── officecaltech.py │ ├── officehome.py │ ├── openset/ │ │ └── __init__.py │ ├── oxfordflowers.py │ ├── oxfordpets.py │ ├── pacs.py │ ├── partial/ │ │ ├── __init__.py │ │ ├── caltech_imagenet.py │ │ └── imagenet_caltech.py │ ├── patchcamelyon.py │ ├── regression/ │ │ ├── __init__.py │ │ ├── dsprites.py │ │ ├── image_regression.py │ │ └── mpi3d.py │ ├── reid/ │ │ ├── __init__.py │ │ ├── basedataset.py │ │ ├── convert.py │ │ ├── dukemtmc.py │ │ ├── market1501.py │ │ ├── msmt17.py │ │ ├── personx.py │ │ └── unreal.py │ ├── resisc45.py │ ├── retinopathy.py │ ├── segmentation/ │ │ ├── __init__.py │ │ ├── cityscapes.py │ │ ├── gta5.py │ │ ├── segmentation_list.py │ │ └── synthia.py │ ├── stanford_cars.py │ ├── stanford_dogs.py │ ├── sun397.py │ └── visda2017.py ├── models/ │ ├── __init__.py │ ├── digits.py │ ├── keypoint_detection/ │ │ ├── __init__.py │ │ ├── loss.py │ │ └── pose_resnet.py │ ├── object_detection/ │ │ ├── __init__.py │ │ ├── backbone/ │ │ │ ├── __init__.py │ │ │ ├── mmdetection/ │ │ │ │ ├── vgg.py │ │ │ │ └── weight_init.py │ │ │ └── vgg.py │ │ ├── meta_arch/ │ │ │ ├── 
__init__.py │ │ │ ├── rcnn.py │ │ │ └── retinanet.py │ │ ├── proposal_generator/ │ │ │ ├── __init__.py │ │ │ └── rpn.py │ │ └── roi_heads/ │ │ ├── __init__.py │ │ └── roi_heads.py │ ├── reid/ │ │ ├── __init__.py │ │ ├── identifier.py │ │ ├── loss.py │ │ └── resnet.py │ ├── resnet.py │ └── segmentation/ │ ├── __init__.py │ └── deeplabv2.py └── transforms/ ├── __init__.py ├── keypoint_detection.py └── segmentation.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: '' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** A clear and concise description of what you expected to happen. **Screenshots** If applicable, add screenshots to help explain your problem. **Desktop (please complete the following information):** - OS: [e.g. iOS] - Browser [e.g. chrome, safari] - Version [e.g. 22] **Smartphone (please complete the following information):** - Device: [e.g. iPhone6] - OS: [e.g. iOS8.1] - Browser [e.g. stock browser, safari] - Version [e.g. 22] **Additional context** Add any other context about the problem here. ================================================ FILE: .github/ISSUE_TEMPLATE/custom.md ================================================ --- name: Custom issue template about: Describe this issue template's purpose here. 
title: '' labels: '' assignees: '' --- ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: '' assignees: '' --- **Is your feature request related to a problem? Please describe.** A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] **Describe the solution you'd like** A clear and concise description of what you want to happen. **Describe alternatives you've considered** A clear and concise description of any alternative solutions or features you've considered. **Additional context** Add any other context or screenshots about the feature request here. ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ pip-wheel-metadata/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ docs/build/* docs/pytorch_sphinx_theme/* # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ .idea/ exp/* trash/* examples/domain_adaptation/digits/logs/* examples/domain_adaptation/digits/data/* .DS_Store */.DS_Store ================================================ FILE: CONTRIBUTING.md ================================================ ## Contributing to Transfer-Learning-Library All kinds of contributions are welcome, including but not limited to the following. - Fix typo or bugs - Add documentation - Add new features and components ### Workflow 1. fork and pull the latest Transfer-Learning-Library repository 2. checkout a new branch (do not use master branch for PRs) 3. commit your changes 4. 
create a PR ```{note} If you plan to add some new features that involve large changes, it is encouraged to open an issue for discussion first. ``` ================================================ FILE: DATASETS.md ================================================ ## Notice (2023-08-01) ### Transfer-Learning-Library Dataset Link Failure Issue Dear users, We sincerely apologize to inform you that the dataset links of Transfer-Learning-Library have become invalid due to a cloud storage failure, resulting in many users being unable to download the datasets properly recently. We are working diligently to resolve this issue and plan to restore the links as soon as possible. Currently, we have already restored some dataset links, and they have been updated on the master branch. You can obtain the latest version by running "git pull." As the version on PyPI has not been updated yet, please temporarily uninstall the old version by running "pip uninstall tllib" before installing the new one. In the future, we are planning to store the datasets on both Baidu Cloud and Google Cloud to provide more stable download links. Additionally, a small portion of datasets that were backed up on our local server have also been lost due to a hard disk failure. For these datasets, we need to re-download and verify them, which might take longer to restore the links. Within this week, we will release the updated dataset and confirm the list of datasets without backup. For datasets without backup, if you have previously downloaded them locally, please contact us via email. Your support is highly appreciated. Once again, we apologize for any inconvenience caused and thank you for your understanding. Sincerely, The Transfer-Learning-Library Team ## Update (2023-08-09) Most of the dataset links have been restored at present. 
The confirmed datasets without backups are as follows: - Classification - COCO70 - EuroSAT - PACS - PatchCamelyon - [Partial Domain Adaptation] - CaltechImageNet - Keypoint Detection - Hand3DStudio - LSP - SURREAL - Object Detection - Comic - Re-Identification - PersonX - UnrealPerson **For these datasets, if you had previously downloaded them locally, please contact us via email. We greatly appreciate everyone's support.** ## Notice (2023-08-01) ### Transfer-Learning-Library数据集链接失效问题 各位使用者,我们很抱歉通知大家,最近Transfer-Learning-Library的数据集链接因为云盘故障而失效,导致很多使用者无法正常下载数据集。 我们正在全力以赴解决这一问题,并计划在最短的时间内恢复链接。目前我们已经恢复了部分数据集链接,更新在master分支上,您可以通过git pull来获取最新的版本。 由于pypi上的版本还未更新,暂时请首先通过pip uninstall tllib卸载旧版本。 日后我们计划将数据集存储在百度云和谷歌云上,提供更加稳定的下载链接。 另外,小部分数据集在我们本地服务器上的备份也由于硬盘故障而丢失,对于这些数据集我们需要重新下载并验证,可能需要更长的时间来恢复链接。 我们会在本周内发布已经更新的数据集和确认无备份的数据集列表,对于无备份的数据集,如果您之前有下载到本地,请通过邮件联系我们,非常感谢大家的支持。 再次向您表达我们的歉意,并感谢您的理解。 Transfer-Learning-Library团队 ## Update (2023-08-09) 目前大部分数据集的链接已经恢复,确认无备份的数据集如下: - Classification - COCO70 - EuroSAT - PACS - PatchCamelyon - [Partial Domain Adaptation] - CaltechImageNet - Keypoint Detection - Hand3DStudio - LSP - SURREAL - Object Detection - Comic - Re-Identification - PersonX - UnrealPerson **对于这些数据集,如果您之前有下载到本地,请通过邮件联系我们,非常感谢大家的支持。** ================================================ FILE: LICENSE ================================================ Copyright (c) 2018 The Python Packaging Authority Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================
# Transfer Learning Library - [Introduction](#introduction) - [Updates](#updates) - [Supported Methods](#supported-methods) - [Installation](#installation) - [Documentation](#documentation) - [Contact](#contact) - [Citation](#citation) ## Update (2024-03-15) We upload an offline version of documentation [here](/docs/html.zip). You can download and unzip it to view the documentation. ## Notice (2023-08-09) A note on broken dataset links can be found here: [DATASETS.md](DATASETS.md). ## Introduction *TLlib* is an open-source and well-documented library for Transfer Learning. It is based on pure PyTorch with high performance and friendly API. Our code is pythonic, and the design is consistent with torchvision. You can easily develop new algorithms, or readily apply existing algorithms. Our _API_ is divided by methods, which include: - domain alignment methods (tllib.alignment) - domain translation methods (tllib.translation) - self-training methods (tllib.self_training) - regularization methods (tllib.regularization) - data reweighting/resampling methods (tllib.reweight) - model ranking/selection methods (tllib.ranking) - normalization-based methods (tllib.normalization) We provide many example codes in the directory _examples_, which is divided by learning setups. Currently, the supported learning setups include: - DA (domain adaptation) - TA (task adaptation, also known as finetune) - OOD (out-of-distribution generalization, also known as DG / domain generalization) - SSL (semi-supervised learning) - Model Selection Our supported tasks include: classification, regression, object detection, segmentation, keypoint detection, and so on. ## Updates ### 2022.9 We support installing *TLlib* via `pip`, which is experimental currently. ```shell pip install -i https://test.pypi.org/simple/ tllib==0.4 ``` ### 2022.8 We release `v0.4` of *TLlib*. Previous versions of *TLlib* can be found [here](https://github.com/thuml/Transfer-Learning-Library/releases). 
In `v0.4`, we add implementations of the following methods: - Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection) [[API]](/tllib/alignment/d_adapt) - Pre-trained Model Selection [[Code]](/examples/model_selection) [[API]](/tllib/ranking) - Semi-supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/) [[API]](/tllib/self_training) Besides, we maintain a collection of **_awesome papers in Transfer Learning_** in another repo [_A Roadmap for Transfer Learning_](https://github.com/thuml/A-Roadmap-for-Transfer-Learning). ### 2022.2 We adjusted our API following our survey [Transferablity in Deep Learning](https://arxiv.org/abs/2201.05867). ## Supported Methods The currently supported algorithms include: ##### Domain Adaptation for Classification [[Code]](/examples/domain_adaptation/image_classification) - **DANN** - Unsupervised Domain Adaptation by Backpropagation [[ICML 2015]](http://proceedings.mlr.press/v37/ganin15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dann.py) - **DAN** - Learning Transferable Features with Deep Adaptation Networks [[ICML 2015]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/deep-adaptation-networks-icml15.pdf) [[Code]](/examples/domain_adaptation/image_classification/dan.py) - **JAN** - Deep Transfer Learning with Joint Adaptation Networks [[ICML 2017]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/joint-adaptation-networks-icml17.pdf) [[Code]](/examples/domain_adaptation/image_classification/jan.py) - **ADDA** - Adversarial Discriminative Domain Adaptation [[CVPR 2017]](http://openaccess.thecvf.com/content_cvpr_2017/papers/Tzeng_Adversarial_Discriminative_Domain_CVPR_2017_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/adda.py) - **CDAN** - Conditional Adversarial Domain Adaptation [[NIPS 2018]](http://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation) 
[[Code]](/examples/domain_adaptation/image_classification/cdan.py) - **MCD** - Maximum Classifier Discrepancy for Unsupervised Domain Adaptation [[CVPR 2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Saito_Maximum_Classifier_Discrepancy_CVPR_2018_paper.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcd.py) - **MDD** - Bridging Theory and Algorithm for Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/zhang19i/zhang19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/mdd.py) - **BSP** - Transferability vs. Discriminability: Batch Spectral Penalization for Adversarial Domain Adaptation [[ICML 2019]](http://proceedings.mlr.press/v97/chen19i/chen19i.pdf) [[Code]](/examples/domain_adaptation/image_classification/bsp.py) - **MCC** - Minimum Class Confusion for Versatile Domain Adaptation [[ECCV 2020]](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123660460.pdf) [[Code]](/examples/domain_adaptation/image_classification/mcc.py) ##### Domain Adaptation for Object Detection [[Code]](/examples/domain_adaptation/object_detection) - **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/object_detection/cycle_gan.py) - **D-adapt** - Decoupled Adaptation for Cross-Domain Object Detection [[ICLR 2022]](https://openreview.net/pdf?id=VNqaB1g9393) [[Code]](/examples/domain_adaptation/object_detection/d_adapt) ##### Domain Adaptation for Semantic Segmentation [[Code]](/examples/domain_adaptation/semantic_segmentation/) - **CycleGAN** - Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks [[ICCV 2017]](https://openaccess.thecvf.com/content_ICCV_2017/papers/Zhu_Unpaired_Image-To-Image_Translation_ICCV_2017_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycle_gan.py) 
- **CyCADA** - Cycle-Consistent Adversarial Domain Adaptation [[ICML 2018]](http://proceedings.mlr.press/v80/hoffman18a.html) [[Code]](/examples/domain_adaptation/semantic_segmentation/cycada.py) - **ADVENT** - Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation [[CVPR 2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Vu_ADVENT_Adversarial_Entropy_Minimization_for_Domain_Adaptation_in_Semantic_Segmentation_CVPR_2019_paper.pdf) [[Code]](/examples/domain_adaptation/semantic_segmentation/advent.py) - **FDA** - Fourier Domain Adaptation for Semantic Segmentation [[CVPR 2020]](https://arxiv.org/abs/2004.05498) [[Code]](/examples/domain_adaptation/semantic_segmentation/fda.py) ##### Domain Adaptation for Keypoint Detection [[Code]](/examples/domain_adaptation/keypoint_detection) - **RegDA** - Regressive Domain Adaptation for Unsupervised Keypoint Detection [[CVPR 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf) [[Code]](/examples/domain_adaptation/keypoint_detection) ##### Domain Adaptation for Person Re-identification [[Code]](/examples/domain_adaptation/re_identification/) - **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - **MMT** - Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification [[ICLR 2020]](https://arxiv.org/abs/2001.01526) [[Code]](/examples/domain_adaptation/re_identification/mmt.py) - **SPGAN** - Similarity Preserving Generative Adversarial Network [[CVPR 2018]](https://arxiv.org/pdf/1811.10551.pdf) [[Code]](/examples/domain_adaptation/re_identification/spgan.py) ##### Partial Domain Adaptation [[Code]](/examples/domain_adaptation/partial_domain_adaptation) - **IWAN** - Importance Weighted Adversarial Nets for Partial Domain Adaptation[[CVPR 
2018]](https://arxiv.org/abs/1803.09210) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/iwan.py) - **AFN** - Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation [[ICCV 2019]](https://arxiv.org/pdf/1811.07456v2.pdf) [[Code]](/examples/domain_adaptation/partial_domain_adaptation/afn.py) ##### Open-set Domain Adaptation [[Code]](/examples/domain_adaptation/openset_domain_adaptation) - **OSBP** - Open Set Domain Adaptation by Backpropagation [[ECCV 2018]](https://arxiv.org/abs/1804.10427) [[Code]](/examples/domain_adaptation/openset_domain_adaptation/osbp.py) ##### Domain Generalization for Classification [[Code]](/examples/domain_generalization/image_classification/) - **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/image_classification/mixstyle.py) - **MLDG** - Learning to Generalize: Meta-Learning for Domain Generalization [[AAAI 2018]](https://arxiv.org/pdf/1710.03463.pdf) [[Code]](/examples/domain_generalization/image_classification/mldg.py) - **IRM** - Invariant Risk Minimization [[ArXiv]](https://arxiv.org/abs/1907.02893) [[Code]](/examples/domain_generalization/image_classification/irm.py) - **VREx** - Out-of-Distribution Generalization via Risk Extrapolation [[ICML 2021]](https://arxiv.org/abs/2003.00688) [[Code]](/examples/domain_generalization/image_classification/vrex.py) - **GroupDRO** - Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization [[ArXiv]](https://arxiv.org/abs/1911.08731) [[Code]](/examples/domain_generalization/image_classification/groupdro.py) - **Deep CORAL** - Correlation Alignment for Deep Domain Adaptation [[ECCV 
2016]](https://arxiv.org/abs/1607.01719) [[Code]](/examples/domain_generalization/image_classification/coral.py) ##### Domain Generalization for Person Re-identification [[Code]](/examples/domain_generalization/re_identification/) - **IBN-Net** - Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net [[ECCV 2018]](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - **MixStyle** - Domain Generalization with MixStyle [[ICLR 2021]](https://arxiv.org/abs/2104.02008) [[Code]](/examples/domain_generalization/re_identification/mixstyle.py) ##### Task Adaptation (Fine-Tuning) for Image Classification [[Code]](/examples/task_adaptation/image_classification/) - **L2-SP** - Explicit inductive bias for transfer learning with convolutional networks [[ICML 2018]](https://arxiv.org/abs/1802.01483) [[Code]](/examples/task_adaptation/image_classification/delta.py) - **BSS** - Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning [[NIPS 2019]](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/bss.py) - **DELTA** - DEep Learning Transfer using Feature Map with Attention for convolutional networks [[ICLR 2019]](https://openreview.net/pdf?id=rkgbwsAcYm) [[Code]](/examples/task_adaptation/image_classification/delta.py) - **Co-Tuning** - Co-Tuning for Transfer Learning [[NIPS 2020]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf) [[Code]](/examples/task_adaptation/image_classification/co_tuning.py) - **StochNorm** - Stochastic Normalization [[NIPS 2020]](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf) [[Code]](/examples/task_adaptation/image_classification/stochnorm.py) - **LWF** - Learning Without Forgetting [[ECCV 2016]](https://arxiv.org/abs/1606.09282) 
[[Code]](/examples/task_adaptation/image_classification/lwf.py) - **Bi-Tuning** - Bi-tuning of Pre-trained Representations [[ArXiv]](https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29) [[Code]](/examples/task_adaptation/image_classification/bi_tuning.py) ##### Pre-trained Model Selection [[Code]](/examples/model_selection) - **H-Score** - An Information-theoretic Approach to Transferability in Task Transfer Learning [[ICIP 2019]](http://yangli-feasibility.com/home/media/icip-19.pdf) [[Code]](/examples/model_selection/hscore.py) - **NCE** - Negative Conditional Entropy in `Transferability and Hardness of Supervised Classification Tasks` [[ICCV 2019]](https://arxiv.org/pdf/1908.08142v1.pdf) [[Code]](/examples/model_selection/nce.py) - **LEEP** - LEEP: A New Measure to Evaluate Transferability of Learned Representations [[ICML 2020]](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf) [[Code]](/examples/model_selection/leep.py) - **LogME** - Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models for Transfer Learning` [[ICML 2021]](https://arxiv.org/pdf/2102.11005.pdf) [[Code]](/examples/model_selection/logme.py) ##### Semi-Supervised Learning for Classification [[Code]](/examples/semi_supervised_learning/image_classification/) - **Pseudo Label** - Pseudo-Label: The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks [[ICML 2013]](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf) [[Code]](/examples/semi_supervised_learning/image_classification/pseudo_label.py) - **Pi Model** - Temporal Ensembling for Semi-Supervised Learning [[ICLR 2017]](https://arxiv.org/abs/1610.02242) [[Code]](/examples/semi_supervised_learning/image_classification/pi_model.py) - **Mean Teacher** - Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning 
results [[NIPS 2017]](https://arxiv.org/abs/1703.01780) [[Code]](/examples/semi_supervised_learning/image_classification/mean_teacher.py) - **Noisy Student** - Self-Training With Noisy Student Improves ImageNet Classification [[CVPR 2020]](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/noisy_student.py) - **UDA** - Unsupervised Data Augmentation for Consistency Training [[NIPS 2020]](https://arxiv.org/pdf/1904.12848v4.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/uda.py) - **FixMatch** - Simplifying Semi-Supervised Learning with Consistency and Confidence [[NIPS 2020]](https://arxiv.org/abs/2001.07685) [[Code]](/examples/semi_supervised_learning/image_classification/fixmatch.py) - **Self-Tuning** - Self-Tuning for Data-Efficient Deep Learning [[ICML 2021]](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/self_tuning.py) - **FlexMatch** - FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling [[NIPS 2021]](https://arxiv.org/abs/2110.08263) [[Code]](/examples/semi_supervised_learning/image_classification/flexmatch.py) - **DebiasMatch** - Debiased Learning From Naturally Imbalanced Pseudo-Labels [[CVPR 2022]](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf) [[Code]](/examples/semi_supervised_learning/image_classification/debiasmatch.py) - **DST** - Debiased Self-Training for Semi-Supervised Learning [[NIPS 2022 Oral]](https://arxiv.org/abs/2202.07136) [[Code]](/examples/semi_supervised_learning/image_classification/dst.py) ## Installation ##### Install from Source Code - Please git clone the library first. 
Then, run the following commands to install `tllib` and all the dependencies. ```shell python setup.py install pip install -r requirements.txt ``` ##### Install via `pip` - Installing via `pip` is currently experimental. ```shell pip install -i https://test.pypi.org/simple/ tllib==0.4 ``` ## Documentation You can find the API documentation on the website: [Documentation](http://tl.thuml.ai/). ## Usage You can find examples in the directory `examples`. A typical usage is ```shell script # Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50. # Assume you have put the datasets under the path `data/office31`, # or you are glad to download the datasets automatically from the Internet to this path python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 ``` ## Contributing We appreciate all contributions. If you are planning to contribute back bug-fixes, please do so without any further discussion. If you plan to contribute new features, utility functions or extensions, please first open an issue and discuss the feature with us. ## Disclaimer on Datasets This is a utility library that downloads and prepares public datasets. We do not host or distribute these datasets, vouch for their quality or fairness, or claim that you have licenses to use the dataset. It is your responsibility to determine whether you have permission to use the dataset under the dataset's license. If you're a dataset owner and wish to update any part of it (description, citation, etc.), or do not want your dataset to be included in this library, please get in touch through a GitHub issue. Thanks for your contribution to the ML community! ## Contact If you have any problem with our code or have some suggestions, including future features, feel free to contact - Baixu Chen (cbx_99_hasta@outlook.com) - Junguang Jiang (JiangJunguang1123@outlook.com) - Mingsheng Long (longmingsheng@gmail.com) or describe it in Issues. 
For Q&A in Chinese, you can choose to ask questions here before sending an email. [迁移学习算法库答疑专区](https://zhuanlan.zhihu.com/p/248104070) ## Citation If you use this toolbox or benchmark in your research, please cite this project. ```latex @misc{jiang2022transferability, title={Transferability in Deep Learning: A Survey}, author={Junguang Jiang and Yang Shu and Jianmin Wang and Mingsheng Long}, year={2022}, eprint={2201.05867}, archivePrefix={arXiv}, primaryClass={cs.LG} } @misc{tllib, author = {Junguang Jiang, Baixu Chen, Bo Fu, Mingsheng Long}, title = {Transfer-Learning-library}, year = {2020}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/thuml/Transfer-Learning-Library}}, } ``` ## Acknowledgment We would like to thank School of Software, Tsinghua University and The National Engineering Laboratory for Big Data Software for providing such an excellent ML research platform. ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SPHINXPROJ = PyTorchSphinxTheme SOURCEDIR = . BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/conf.py ================================================ import sys import os sys.path.append(os.path.abspath('..')) sys.path.append(os.path.abspath('./demo/')) from pytorch_sphinx_theme import __version__ import pytorch_sphinx_theme # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. #sys.path.insert(0, os.path.abspath('.')) # -- General configuration ----------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. #needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ 'sphinx.ext.intersphinx', 'sphinx.ext.autodoc', 'sphinx.ext.viewcode', 'sphinxcontrib.httpdomain', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel', 'sphinx.ext.napoleon', ] # build the templated autosummary files autosummary_generate = True numpydoc_show_class_members = False # autosectionlabel throws warnings if section names are duplicated. # The following tells autosectionlabel to not throw a warning for # duplicated section names that are in different documents. autosectionlabel_prefix_document = True napoleon_use_ivar = True # Do not warn about external images (status badges in README.rst) suppress_warnings = ['image.nonlocal_uri'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix of source filenames. source_suffix = '.rst' # The encoding of source files. #source_encoding = 'utf-8-sig' # The master toctree document. master_doc = 'index' # General information about the project. 
# -- General project information -----------------------------------------------

project = u'Transfer Learning Library'
copyright = u'THUML Group'

# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# NOTE: __version__ is imported from pytorch_sphinx_theme at the top of this
# file, so the displayed docs version tracks the theme package's version.
#
# The short X.Y version.
version = __version__
# The full version, including alpha/beta/rc tags.
release = __version__

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
language = 'en'

# There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = []

# The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None

# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True

# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True

# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'

# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []

# Cross-project reference targets for sphinx.ext.intersphinx: lets the docs
# link directly into the Python/NumPy/PyTorch/torchvision/Pillow docs.
intersphinx_mapping = {
    'rtd': ('https://docs.readthedocs.io/en/latest/', None),
    'python': ('https://docs.python.org/3', None),
    'numpy': ('https://numpy.org/doc/stable', None),
    'torch': ('https://pytorch.org/docs/stable', None),
    'torchvision': ('https://pytorch.org/vision/stable', None),
    'PIL': ('https://pillow.readthedocs.io/en/stable/', None)
}

# -- Options for HTML output ---------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'pytorch_sphinx_theme'
html_theme_path = [pytorch_sphinx_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    'canonical_url': '',
    'analytics_id': '',
    'logo_only': False,
    'display_version': False,
    'prev_next_buttons_location': 'bottom',
    'style_external_links': False,
    # Toc options
    'collapse_navigation': True,
    'sticky_navigation': False,
    'navigation_depth': 4,
    'includehidden': True,
    'titles_only': False
}

# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
#html_title = None

# A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None

# The name of an image file (relative to this directory) to place at the top
# of the sidebar.
html_logo = "_static/images/TransLearn.png"

# The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large.
#html_favicon = None

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'

# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True

# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}

# Additional templates that should be rendered to pages, maps page names to
# template names.
#html_additional_pages = {}

# If false, no module index is generated.
#html_domain_indices = True

# If false, no index is generated.
#html_use_index = True

# If true, the index is split into individual pages for each letter.
#html_split_index = False

# If true, links to the reST sources are added to the pages.
html_show_sourcelink = True

# If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True

# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True

# If true, an OpenSearch description file will be output, and all pages will
# contain a tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served.
#html_use_opensearch = ''

# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None

# Disable displaying type annotations, these can be very verbose
autodoc_typehints = 'none'

# Output file base name for HTML help builder.
htmlhelp_basename = 'TransferLearningLibrary'

# -- Options for LaTeX output --------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #'preamble': '',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
    ('index', 'TransferLearningLibrary.tex',
     u'Transfer Learning Library Documentation',
     u'THUML', 'manual'),
]

# The name of an image file (relative to this directory) to place at the top of
# the title page.
#latex_logo = None

# For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters.
#latex_use_parts = False

# If true, show page references after internal links.
#latex_show_pagerefs = False

# If true, show URL addresses after external links.
#latex_show_urls = False

# Documents to append as an appendix to all manuals.
#latex_appendices = [] # If false, no module index is generated. #latex_domain_indices = True # -- Options for manual page output -------------------------------------------- # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation', [u'THUML'], 1) ] # If true, show URL addresses after external links. #man_show_urls = False # -- Options for Texinfo output ------------------------------------------------ # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ ('index', 'Transfer Learning Library', u'Transfer Learning Library Documentation', u'THUML', 'Transfer Learning Library', 'One line description of project.', 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. #texinfo_appendices = [] # If false, no module index is generated. #texinfo_domain_indices = True # How to display URL addresses: 'footnote', 'no', or 'inline'. #texinfo_show_urls = 'footnote' ================================================ FILE: docs/index.rst ================================================ ===================================== Transfer Learning ===================================== .. toctree:: :maxdepth: 2 :caption: Transfer Learning API :titlesonly: tllib/modules tllib/alignment/index tllib/translation tllib/self_training tllib/reweight tllib/normalization tllib/regularization tllib/ranking .. toctree:: :maxdepth: 2 :caption: Common API :titlesonly: tllib/vision/index tllib/utils/index ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=python -msphinx ) set SPHINXOPTS= set SPHINXBUILD=sphinx-build set SOURCEDIR=. 
set BUILDDIR=build set SPHINXPROJ=PyTorchSphinxTheme if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The Sphinx module was not found. Make sure you have Sphinx installed, echo.then set the SPHINXBUILD environment variable to point to the full echo.path of the 'sphinx-build' executable. Alternatively you may add the echo.Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/requirements.txt ================================================ sphinxcontrib-httpdomain sphinx ================================================ FILE: docs/tllib/alignment/domain_adversarial.rst ================================================ ========================================== Domain Adversarial Training ========================================== .. _DANN: DANN: Domain Adversarial Neural Network ---------------------------------------- .. autoclass:: tllib.alignment.dann.DomainAdversarialLoss .. _CDAN: CDAN: Conditional Domain Adversarial Network ----------------------------------------------- .. autoclass:: tllib.alignment.cdan.ConditionalDomainAdversarialLoss .. autoclass:: tllib.alignment.cdan.RandomizedMultiLinearMap .. autoclass:: tllib.alignment.cdan.MultiLinearMap .. _ADDA: ADDA: Adversarial Discriminative Domain Adaptation ----------------------------------------------------- .. autoclass:: tllib.alignment.adda.DomainAdversarialLoss .. note:: ADDAgrl is also implemented and benchmarked. You can find code `here `_. .. _BSP: BSP: Batch Spectral Penalization ----------------------------------- .. autoclass:: tllib.alignment.bsp.BatchSpectralPenalizationLoss .. _OSBP: OSBP: Open Set Domain Adaptation by Backpropagation ---------------------------------------------------- .. 
autoclass:: tllib.alignment.osbp.UnknownClassBinaryCrossEntropy .. _ADVENT: ADVENT: Adversarial Entropy Minimization for Semantic Segmentation ------------------------------------------------------------------ .. autoclass:: tllib.alignment.advent.Discriminator .. autoclass:: tllib.alignment.advent.DomainAdversarialEntropyLoss :members: .. _DADAPT: D-adapt: Decoupled Adaptation for Cross-Domain Object Detection ---------------------------------------------------------------- `Origin Paper `_. .. autoclass:: tllib.alignment.d_adapt.proposal.Proposal .. autoclass:: tllib.alignment.d_adapt.proposal.PersistentProposalList .. autoclass:: tllib.alignment.d_adapt.proposal.ProposalDataset .. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledGeneralizedRCNN .. autoclass:: tllib.alignment.d_adapt.modeling.meta_arch.DecoupledRetinaNet ================================================ FILE: docs/tllib/alignment/hypothesis_adversarial.rst ================================================ ========================================== Hypothesis Adversarial Learning ========================================== .. _MCD: MCD: Maximum Classifier Discrepancy -------------------------------------------- .. autofunction:: tllib.alignment.mcd.classifier_discrepancy .. autofunction:: tllib.alignment.mcd.entropy .. autoclass:: tllib.alignment.mcd.ImageClassifierHead .. _MDD: MDD: Margin Disparity Discrepancy -------------------------------------------- .. autoclass:: tllib.alignment.mdd.MarginDisparityDiscrepancy **MDD for Classification** .. autoclass:: tllib.alignment.mdd.ClassificationMarginDisparityDiscrepancy .. autoclass:: tllib.alignment.mdd.ImageClassifier :members: .. autofunction:: tllib.alignment.mdd.shift_log **MDD for Regression** .. autoclass:: tllib.alignment.mdd.RegressionMarginDisparityDiscrepancy .. autoclass:: tllib.alignment.mdd.ImageRegressor .. _RegDA: RegDA: Regressive Domain Adaptation -------------------------------------------- .. 
autoclass:: tllib.alignment.regda.PseudoLabelGenerator2d .. autoclass:: tllib.alignment.regda.RegressionDisparity .. autoclass:: tllib.alignment.regda.PoseResNet2d ================================================ FILE: docs/tllib/alignment/index.rst ================================================ ===================================== Feature Alignment ===================================== .. toctree:: :maxdepth: 3 :caption: Feature Alignment :titlesonly: statistics_matching domain_adversarial hypothesis_adversarial ================================================ FILE: docs/tllib/alignment/statistics_matching.rst ================================================ ===================== Statistics Matching ===================== .. _DAN: DAN: Deep Adaptation Network ----------------------------- .. autoclass:: tllib.alignment.dan.MultipleKernelMaximumMeanDiscrepancy .. _CORAL: Deep CORAL: Correlation Alignment for Deep Domain Adaptation -------------------------------------------------------------- .. autoclass:: tllib.alignment.coral.CorrelationAlignmentLoss .. _JAN: JAN: Joint Adaptation Network ------------------------------ .. autoclass:: tllib.alignment.jan.JointMultipleKernelMaximumMeanDiscrepancy ================================================ FILE: docs/tllib/modules.rst ================================================ ===================== Modules ===================== Classifier ------------------------------- .. autoclass:: tllib.modules.classifier.Classifier :members: Regressor ------------------------------- .. autoclass:: tllib.modules.regressor.Regressor :members: Domain Discriminator ------------------------------- .. autoclass:: tllib.modules.domain_discriminator.DomainDiscriminator :members: GRL: Gradient Reverse Layer ----------------------------- .. autoclass:: tllib.modules.grl.WarmStartGradientReverseLayer :members: Gaussian Kernels ------------------------ .. autoclass:: tllib.modules.kernels.GaussianKernel Entropy ------------------------ .. 
autofunction:: tllib.modules.entropy.entropy Knowledge Distillation Loss ------------------------------- .. autoclass:: tllib.modules.loss.KnowledgeDistillationLoss :members: ================================================ FILE: docs/tllib/normalization.rst ================================================ ===================== Normalization ===================== .. _AFN: AFN: Adaptive Feature Norm ----------------------------- .. autoclass:: tllib.normalization.afn.AdaptiveFeatureNorm .. autoclass:: tllib.normalization.afn.Block .. autoclass:: tllib.normalization.afn.ImageClassifier StochNorm: Stochastic Normalization ------------------------------------------ .. autoclass:: tllib.normalization.stochnorm.StochNorm1d .. autoclass:: tllib.normalization.stochnorm.StochNorm2d .. autoclass:: tllib.normalization.stochnorm.StochNorm3d .. autofunction:: tllib.normalization.stochnorm.convert_model .. _IBN: IBN-Net: Instance-Batch Normalization Network ------------------------------------------------ .. autoclass:: tllib.normalization.ibn.InstanceBatchNorm2d .. autoclass:: tllib.normalization.ibn.IBNNet :members: .. automodule:: tllib.normalization.ibn :members: .. _MIXSTYLE: MixStyle: Domain Generalization with MixStyle ------------------------------------------------- .. autoclass:: tllib.normalization.mixstyle.MixStyle .. note:: MixStyle is only activated during `training` stage, with some probability :math:`p`. .. automodule:: tllib.normalization.mixstyle.resnet :members: ================================================ FILE: docs/tllib/ranking.rst ================================================ ===================== Ranking ===================== .. _H_score: H-score ------------------------------------------- .. autofunction:: tllib.ranking.hscore.h_score .. _LEEP: LEEP: Log Expected Empirical Prediction ------------------------------------------- .. autofunction:: tllib.ranking.leep.log_expected_empirical_prediction .. 
_NCE: NCE: Negative Conditional Entropy ------------------------------------------- .. autofunction:: tllib.ranking.nce.negative_conditional_entropy .. _LogME: LogME: Log Maximum Evidence ------------------------------------------- .. autofunction:: tllib.ranking.logme.log_maximum_evidence ================================================ FILE: docs/tllib/regularization.rst ================================================ =========================================== Regularization =========================================== .. _L2: L2 ------ .. autoclass:: tllib.regularization.delta.L2Regularization .. _L2SP: L2-SP ------ .. autoclass:: tllib.regularization.delta.SPRegularization .. _DELTA: DELTA: DEep Learning Transfer using Feature Map with Attention ------------------------------------------------------------------------------------- .. autoclass:: tllib.regularization.delta.BehavioralRegularization .. autoclass:: tllib.regularization.delta.AttentionBehavioralRegularization .. autoclass:: tllib.regularization.delta.IntermediateLayerGetter .. _LWF: LWF: Learning without Forgetting ------------------------------------------ .. autoclass:: tllib.regularization.lwf.Classifier .. _CoTuning: Co-Tuning ------------------------------------------ .. autoclass:: tllib.regularization.co_tuning.CoTuningLoss .. autoclass:: tllib.regularization.co_tuning.Relationship .. _StochNorm: .. _BiTuning: Bi-Tuning ------------------------------------------ .. autoclass:: tllib.regularization.bi_tuning.BiTuning .. _BSS: BSS: Batch Spectral Shrinkage ------------------------------------------ .. autoclass:: tllib.regularization.bss.BatchSpectralShrinkage ================================================ FILE: docs/tllib/reweight.rst ================================================ ======================================= Re-weighting ======================================= .. _PADA: PADA: Partial Adversarial Domain Adaptation --------------------------------------------- .. 
autoclass:: tllib.reweight.pada.ClassWeightModule .. autoclass:: tllib.reweight.pada.AutomaticUpdateClassWeightModule :members: .. autofunction:: tllib.reweight.pada.collect_classification_results .. _IWAN: IWAN: Importance Weighted Adversarial Nets --------------------------------------------- .. autoclass:: tllib.reweight.iwan.ImportanceWeightModule :members: .. _GroupDRO: GroupDRO: Group Distributionally robust optimization ------------------------------------------------------ .. autoclass:: tllib.reweight.groupdro.AutomaticUpdateDomainWeightModule :members: ================================================ FILE: docs/tllib/self_training.rst ================================================ ======================================= Self Training Methods ======================================= .. _PseudoLabel: Pseudo Label ----------------------------- .. autoclass:: tllib.self_training.pseudo_label.ConfidenceBasedSelfTrainingLoss .. _PiModel: :math:`\Pi` Model ----------------------------- .. autoclass:: tllib.self_training.pi_model.ConsistencyLoss .. autoclass:: tllib.self_training.pi_model.L2ConsistencyLoss .. _MeanTeacher: Mean Teacher ----------------------------- .. autoclass:: tllib.self_training.mean_teacher.EMATeacher .. _SelfEnsemble: Self Ensemble ----------------------------- .. autoclass:: tllib.self_training.self_ensemble.ClassBalanceLoss .. _UDA: UDA ----------------------------- .. autoclass:: tllib.self_training.uda.StrongWeakConsistencyLoss .. _MCC: MCC: Minimum Class Confusion ----------------------------- .. autoclass:: tllib.self_training.mcc.MinimumClassConfusionLoss .. _MMT: MMT: Mutual Mean-Teaching -------------------------- `Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (ICLR 2020) `_ State of the art unsupervised domain adaptation methods utilize clustering algorithms to generate pseudo labels on target domain, which are noisy and thus harmful for training. 
Inspired by the teacher-student approaches, MMT framework provides robust soft pseudo labels in an on-line peer-teaching manner. We denote two networks as :math:`f_1,f_2`, their parameters as :math:`\theta_1,\theta_2`. The authors also propose to use the temporally average model of each network :math:`\text{ensemble}(f_1),\text{ensemble}(f_2)` to generate more reliable soft pseudo labels for supervising the other network. Specifically, the parameters of the temporally average models of the two networks at current iteration :math:`T` are denoted as :math:`E^{(T)}[\theta_1]` and :math:`E^{(T)}[\theta_2]` respectively, which can be calculated as .. math:: E^{(T)}[\theta_1] = \alpha E^{(T-1)}[\theta_1] + (1-\alpha)\theta_1 .. math:: E^{(T)}[\theta_2] = \alpha E^{(T-1)}[\theta_2] + (1-\alpha)\theta_2 where :math:`E^{(T-1)}[\theta_1],E^{(T-1)}[\theta_2]` indicate the temporal average parameters of the two networks in the previous iteration :math:`(T-1)`, the initial temporal average parameters are :math:`E^{(0)}[\theta_1]=\theta_1,E^{(0)}[\theta_2]=\theta_2` and :math:`\alpha` is the momentum. These two networks cooperate with each other in three ways: - When running clustering algorithm, we average features produced by :math:`\text{ensemble}(f_1)` and :math:`\text{ensemble}(f_2)` instead of only considering one of them. - A **soft triplet loss** is optimized between :math:`f_1` and :math:`\text{ensemble}(f_2)` and vice versa to force one network to learn from temporally average of another network. - A **cross entropy loss** is optimized between :math:`f_1` and :math:`\text{ensemble}(f_2)` and vice versa to force one network to learn from temporally average of another network. The above mentioned loss functions are listed below, more details can be found in training scripts. .. autoclass:: tllib.vision.models.reid.loss.SoftTripletLoss .. autoclass:: tllib.vision.models.reid.loss.CrossEntropyLoss .. _SelfTuning: Self Tuning ----------------------------- .. 
autoclass:: tllib.self_training.self_tuning.Classifier .. autoclass:: tllib.self_training.self_tuning.SelfTuning .. _FlexMatch: FlexMatch ----------------------------- .. autoclass:: tllib.self_training.flexmatch.DynamicThresholdingModule :members: .. _DST: Debiased Self-Training ----------------------------- .. autoclass:: tllib.self_training.dst.ImageClassifier .. autoclass:: tllib.self_training.dst.WorstCaseEstimationLoss ================================================ FILE: docs/tllib/translation.rst ================================================ ======================================= Domain Translation ======================================= .. _CycleGAN: ------------------------------------------------ CycleGAN: Cycle-Consistent Adversarial Networks ------------------------------------------------ Discriminator -------------- .. autofunction:: tllib.translation.cyclegan.pixel .. autofunction:: tllib.translation.cyclegan.patch Generator -------------- .. autofunction:: tllib.translation.cyclegan.resnet_9 .. autofunction:: tllib.translation.cyclegan.resnet_6 .. autofunction:: tllib.translation.cyclegan.unet_256 .. autofunction:: tllib.translation.cyclegan.unet_128 GAN Loss -------------- .. autoclass:: tllib.translation.cyclegan.LeastSquaresGenerativeAdversarialLoss .. autoclass:: tllib.translation.cyclegan.VanillaGenerativeAdversarialLoss .. autoclass:: tllib.translation.cyclegan.WassersteinGenerativeAdversarialLoss Translation -------------- .. autoclass:: tllib.translation.cyclegan.Translation Util ---------------- .. autoclass:: tllib.translation.cyclegan.util.ImagePool :members: .. autofunction:: tllib.translation.cyclegan.util.set_requires_grad .. _Cycada: -------------------------------------------------------------- CyCADA: Cycle-Consistent Adversarial Domain Adaptation -------------------------------------------------------------- .. autoclass:: tllib.translation.cycada.SemanticConsistency .. 
_SPGAN: ----------------------------------------------------------- SPGAN: Similarity Preserving Generative Adversarial Network ----------------------------------------------------------- `Image-Image Domain Adaptation with Preserved Self-Similarity and Domain-Dissimilarity for Person Re-identification `_. SPGAN is based on CycleGAN. An additional Siamese network is adopted to force the generator to produce images different from identities in target dataset. Siamese Network ------------------- .. autoclass:: tllib.translation.spgan.siamese.SiameseNetwork Contrastive Loss ------------------- .. autoclass:: tllib.translation.spgan.loss.ContrastiveLoss .. _FDA: ------------------------------------------------ FDA: Fourier Domain Adaptation ------------------------------------------------ .. autoclass:: tllib.translation.fourier_transform.FourierTransform .. autofunction:: tllib.translation.fourier_transform.low_freq_mutate ================================================ FILE: docs/tllib/utils/analysis.rst ================================================ ============== Analysis Tools ============== .. autofunction:: tllib.utils.analysis.collect_feature .. autofunction:: tllib.utils.analysis.a_distance.calculate .. autofunction:: tllib.utils.analysis.tsne.visualize ================================================ FILE: docs/tllib/utils/base.rst ================================================ Generic Tools ============== Average Meter --------------------------------- .. autoclass:: tllib.utils.meter.AverageMeter :members: Progress Meter --------------------------------- .. autoclass:: tllib.utils.meter.ProgressMeter :members: Meter --------------------------------- .. autoclass:: tllib.utils.meter.Meter :members: Data --------------------------------- .. autoclass:: tllib.utils.data.ForeverDataIterator :members: .. autoclass:: tllib.utils.data.CombineDataset :members: .. autofunction:: tllib.utils.data.send_to_device .. 
autofunction:: tllib.utils.data.concatenate Logger ----------- .. autoclass:: tllib.utils.logger.TextLogger :members: .. autoclass:: tllib.utils.logger.CompleteLogger :members: ================================================ FILE: docs/tllib/utils/index.rst ================================================ ===================================== Utilities ===================================== .. toctree:: :maxdepth: 2 :caption: Utilities :titlesonly: base metric analysis ================================================ FILE: docs/tllib/utils/metric.rst ================================================ =========== Metrics =========== Classification & Segmentation ============================== Accuracy --------------------------------- .. autofunction:: tllib.utils.metric.accuracy ConfusionMatrix --------------------------------- .. autoclass:: tllib.utils.metric.ConfusionMatrix :members: ================================================ FILE: docs/tllib/vision/datasets.rst ================================================ Datasets ============================= Cross-Domain Classification --------------------------------------------------------- -------------------------------------- ImageList -------------------------------------- .. autoclass:: tllib.vision.datasets.imagelist.ImageList :members: ------------------------------------- Office-31 ------------------------------------- .. autoclass:: tllib.vision.datasets.office31.Office31 :members: :inherited-members: --------------------------------------- Office-Caltech --------------------------------------- .. autoclass:: tllib.vision.datasets.officecaltech.OfficeCaltech :members: :inherited-members: --------------------------------------- Office-Home --------------------------------------- .. autoclass:: tllib.vision.datasets.officehome.OfficeHome :members: :inherited-members: -------------------------------------- VisDA-2017 -------------------------------------- .. 
autoclass:: tllib.vision.datasets.visda2017.VisDA2017 :members: :inherited-members: -------------------------------------- DomainNet -------------------------------------- .. autoclass:: tllib.vision.datasets.domainnet.DomainNet :members: :inherited-members: -------------------------------------- PACS -------------------------------------- .. autoclass:: tllib.vision.datasets.pacs.PACS :members: -------------------------------------- MNIST -------------------------------------- .. autoclass:: tllib.vision.datasets.digits.MNIST :members: -------------------------------------- USPS -------------------------------------- .. autoclass:: tllib.vision.datasets.digits.USPS :members: -------------------------------------- SVHN -------------------------------------- .. autoclass:: tllib.vision.datasets.digits.SVHN :members: Partial Cross-Domain Classification ---------------------------------------------------- --------------------------------------- Partial Wrapper --------------------------------------- .. autofunction:: tllib.vision.datasets.partial.partial .. autofunction:: tllib.vision.datasets.partial.default_partial --------------------------------------- Caltech-256->ImageNet-1k --------------------------------------- .. autoclass:: tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet :members: --------------------------------------- ImageNet-1k->Caltech-256 --------------------------------------- .. autoclass:: tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech :members: Open Set Cross-Domain Classification ------------------------------------------------------ --------------------------------------- Open Set Wrapper --------------------------------------- .. autofunction:: tllib.vision.datasets.openset.open_set .. 
autofunction:: tllib.vision.datasets.openset.default_open_set Cross-Domain Regression ------------------------------------------------------ --------------------------------------- ImageRegression --------------------------------------- .. autoclass:: tllib.vision.datasets.regression.image_regression.ImageRegression :members: --------------------------------------- DSprites --------------------------------------- .. autoclass:: tllib.vision.datasets.regression.dsprites.DSprites :members: --------------------------------------- MPI3D --------------------------------------- .. autoclass:: tllib.vision.datasets.regression.mpi3d.MPI3D :members: Cross-Domain Segmentation ----------------------------------------------- --------------------------------------- SegmentationList --------------------------------------- .. autoclass:: tllib.vision.datasets.segmentation.segmentation_list.SegmentationList :members: --------------------------------------- Cityscapes --------------------------------------- .. autoclass:: tllib.vision.datasets.segmentation.cityscapes.Cityscapes --------------------------------------- GTA5 --------------------------------------- .. autoclass:: tllib.vision.datasets.segmentation.gta5.GTA5 --------------------------------------- Synthia --------------------------------------- .. autoclass:: tllib.vision.datasets.segmentation.synthia.Synthia --------------------------------------- Foggy Cityscapes --------------------------------------- .. autoclass:: tllib.vision.datasets.segmentation.cityscapes.FoggyCityscapes Cross-Domain Keypoint Detection ----------------------------------------------- --------------------------------------- Dataset Base for Keypoint Detection --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.KeypointDataset :members: .. autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Body16KeypointDataset :members: .. 
autoclass:: tllib.vision.datasets.keypoint_detection.keypoint_dataset.Hand21KeypointDataset :members: --------------------------------------- Rendered Handpose Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.rendered_hand_pose.RenderedHandPose :members: --------------------------------------- Hand-3d-Studio Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.hand_3d_studio.Hand3DStudio :members: --------------------------------------- FreiHAND Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.freihand.FreiHand :members: --------------------------------------- Surreal Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.surreal.SURREAL :members: --------------------------------------- LSP Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.lsp.LSP :members: --------------------------------------- Human3.6M Dataset --------------------------------------- .. autoclass:: tllib.vision.datasets.keypoint_detection.human36m.Human36M :members: Cross-Domain ReID ------------------------------------------------------ --------------------------------------- Market1501 --------------------------------------- .. autoclass:: tllib.vision.datasets.reid.market1501.Market1501 :members: --------------------------------------- DukeMTMC-reID --------------------------------------- .. autoclass:: tllib.vision.datasets.reid.dukemtmc.DukeMTMC :members: --------------------------------------- MSMT17 --------------------------------------- .. autoclass:: tllib.vision.datasets.reid.msmt17.MSMT17 :members: Natural Object Recognition --------------------------------------------------------- ------------------------------------- Stanford Dogs ------------------------------------- .. 
autoclass:: tllib.vision.datasets.stanford_dogs.StanfordDogs :members: ------------------------------------- Stanford Cars ------------------------------------- .. autoclass:: tllib.vision.datasets.stanford_cars.StanfordCars :members: ------------------------------------- CUB-200-2011 ------------------------------------- .. autoclass:: tllib.vision.datasets.cub200.CUB200 :members: ------------------------------------- FVGC Aircraft ------------------------------------- .. autoclass:: tllib.vision.datasets.aircrafts.Aircraft :members: ------------------------------------- Oxford-IIIT Pets ------------------------------------- .. autoclass:: tllib.vision.datasets.oxfordpets.OxfordIIITPets :members: ------------------------------------- COCO-70 ------------------------------------- .. autoclass:: tllib.vision.datasets.coco70.COCO70 :members: ------------------------------------- DTD ------------------------------------- .. autoclass:: tllib.vision.datasets.dtd.DTD :members: ------------------------------------- OxfordFlowers102 ------------------------------------- .. autoclass:: tllib.vision.datasets.oxfordflowers.OxfordFlowers102 :members: ------------------------------------- Caltech101 ------------------------------------- .. autoclass:: tllib.vision.datasets.caltech101.Caltech101 :members: Specialized Image Classification -------------------------------- ------------------------------------- PatchCamelyon ------------------------------------- .. autoclass:: tllib.vision.datasets.patchcamelyon.PatchCamelyon :members: ------------------------------------- Retinopathy ------------------------------------- .. autoclass:: tllib.vision.datasets.retinopathy.Retinopathy :members: ------------------------------------- EuroSAT ------------------------------------- .. autoclass:: tllib.vision.datasets.eurosat.EuroSAT :members: ------------------------------------- Resisc45 ------------------------------------- .. 
autoclass:: tllib.vision.datasets.resisc45.Resisc45 :members: ------------------------------------- Food-101 ------------------------------------- .. autoclass:: tllib.vision.datasets.food101.Food101 :members: ------------------------------------- SUN397 ------------------------------------- .. autoclass:: tllib.vision.datasets.sun397.SUN397 :members: ================================================ FILE: docs/tllib/vision/index.rst ================================================ ===================================== Vision ===================================== .. toctree:: :maxdepth: 2 :caption: Vision :titlesonly: datasets models transforms ================================================ FILE: docs/tllib/vision/models.rst ================================================ Models =========================== ------------------------------ Image Classification ------------------------------ ResNets --------------------------------- .. automodule:: tllib.vision.models.resnet :members: LeNet -------------------------- .. automodule:: tllib.vision.models.digits.lenet :members: DTN -------------------------- .. automodule:: tllib.vision.models.digits.dtn :members: ---------------------------------- Object Detection ---------------------------------- .. autoclass:: tllib.vision.models.object_detection.meta_arch.TLGeneralizedRCNN :members: .. autoclass:: tllib.vision.models.object_detection.meta_arch.TLRetinaNet :members: .. autoclass:: tllib.vision.models.object_detection.proposal_generator.rpn.TLRPN .. autoclass:: tllib.vision.models.object_detection.roi_heads.TLRes5ROIHeads :members: .. autoclass:: tllib.vision.models.object_detection.roi_heads.TLStandardROIHeads :members: ---------------------------------- Semantic Segmentation ---------------------------------- .. 
autofunction:: tllib.vision.models.segmentation.deeplabv2.deeplabv2_resnet101 ---------------------------------- Keypoint Detection ---------------------------------- PoseResNet -------------------------- .. autofunction:: tllib.vision.models.keypoint_detection.pose_resnet.pose_resnet101 .. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.PoseResNet .. autoclass:: tllib.vision.models.keypoint_detection.pose_resnet.Upsampling Joint Loss ---------------------------------- .. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsMSELoss .. autoclass:: tllib.vision.models.keypoint_detection.loss.JointsKLLoss ----------------------------------- Re-Identification ----------------------------------- Models --------------- .. autoclass:: tllib.vision.models.reid.resnet.ReidResNet .. automodule:: tllib.vision.models.reid.resnet :members: .. autoclass:: tllib.vision.models.reid.identifier.ReIdentifier :members: Loss ----------------------------------- .. autoclass:: tllib.vision.models.reid.loss.TripletLoss Sampler ----------------------------------- .. autoclass:: tllib.utils.data.RandomMultipleGallerySampler ================================================ FILE: docs/tllib/vision/transforms.rst ================================================ Transforms ============================= Classification --------------------------------- .. automodule:: tllib.vision.transforms :members: Segmentation --------------------------------- .. automodule:: tllib.vision.transforms.segmentation :members: Keypoint Detection --------------------------------- .. automodule:: tllib.vision.transforms.keypoint_detection :members: ================================================ FILE: examples/domain_adaptation/image_classification/README.md ================================================ # Unsupervised Domain Adaptation for Image Classification ## Installation It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results. 
Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [MNIST](http://yann.lecun.com/exdb/mnist/), [SVHN](http://ufldl.stanford.edu/housenumbers/) , [USPS](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps) - [Office31](https://www.cc.gatech.edu/~judy/domainadapt/) - [OfficeCaltech](https://www.cc.gatech.edu/~judy/domainadapt/) - [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html) - [VisDA2017](http://ai.bu.edu/visda-2017/) - [DomainNet](http://ai.bu.edu/M3SDA/) You need to prepare following datasets manually if you want to use them: - [ImageNet](https://www.image-net.org/) - [ImageNetR](https://github.com/hendrycks/imagenet-r) - [ImageNet-Sketch](https://github.com/HaohanWang/ImageNet-Sketch) and prepare them following [Documentation for ImageNetR](/common/vision/datasets/imagenet_r.py) and [ImageNet-Sketch](/common/vision/datasets/imagenet_sketch.py). 
## Supported Methods Supported methods include: - [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818) - [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791) - [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636) - [Adversarial Discriminative Domain Adaptation (ADDA)](https://arxiv.org/pdf/1702.05464.pdf) - [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667) - [Maximum Classifier Discrepancy (MCD)](https://arxiv.org/abs/1712.02560) - [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf) - [Batch Spectral Penalization (BSP)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf) - [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801) - [Minimum Class Confusion (MCC)](https://arxiv.org/abs/1912.03699) - [FixMatch](https://arxiv.org/abs/2001.07685) ## Usage The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to train DANN on Office31, use the following script ```shell script # Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50. # Assume you have put the datasets under the path `data/office-31`, # or you are glad to download the datasets automatically from the Internet to this path CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W ``` Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store results. After running the above command, it will download ``Office-31`` datasets from the Internet if it's the first time you run the code. Directory that stores datasets will be named as ``examples/domain_adaptation/image_classification/data/``. 
If everything works fine, you will see results in following format:: Epoch: [1][ 900/1000] Time 0.60 ( 0.69) Data 0.22 ( 0.31) Loss 0.74 ( 0.85) Cls Acc 96.9 (95.1) Domain Acc 64.1 (62.6) You can also watch these results in the log file ``logs/dann/Office31_A2W/log.txt``. After training, you can test your algorithm's performance by passing in ``--phase test``. ``` CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase test ``` ## Experiment and Results **Notations** - ``Origin`` means the accuracy reported by the original paper. - ``Avg`` is the accuracy reported by `TLlib`. - ``ERM`` refers to the model trained with data from the source domain. - ``Oracle`` refers to the model trained with data from the target domain. We found that the accuracies of adversarial methods (including DANN, ADDA, CDAN, MCD, BSP and MDD) are not stable even after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017* for three times and report their average accuracy. 
### Office-31 accuracy on ResNet-50 | Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A | |---------|--------|------|-------|-------|-------|-------|-------|-------| | ERM | 76.1 | 79.5 | 75.8 | 95.5 | 99.0 | 79.3 | 63.6 | 63.8 | | DANN | 82.2 | 86.1 | 91.4 | 97.9 | 100.0 | 83.6 | 73.3 | 70.4 | | ADDA | / | 87.3 | 94.6 | 97.5 | 99.7 | 90.0 | 69.6 | 72.5 | | BSP | 87.7 | 87.8 | 92.7 | 97.9 | 100.0 | 88.2 | 74.1 | 73.8 | | DAN | 80.4 | 83.7 | 84.2 | 98.4 | 100.0 | 87.3 | 66.9 | 65.2 | | JAN | 84.3 | 87.0 | 93.7 | 98.4 | 100.0 | 89.4 | 69.2 | 71.0 | | CDAN | 87.7 | 87.7 | 93.8 | 98.5 | 100.0 | 89.9 | 73.4 | 70.4 | | MCD | / | 85.4 | 90.4 | 98.5 | 100.0 | 87.3 | 68.3 | 67.6 | | AFN | 85.7 | 88.6 | 94.0 | 98.9 | 100.0 | 94.4 | 72.9 | 71.1 | | MDD | 88.9 | 89.6 | 95.6 | 98.6 | 100.0 | 94.4 | 76.6 | 72.2 | | MCC | 89.4 | 89.6 | 94.1 | 98.4 | 99.8 | 95.6 | 75.5 | 74.2 | | FixMatch| / | 86.4 | 86.4 | 98.2 | 100.0 | 95.4 | 70.0 | 68.1 | ### Office-Home accuracy on ResNet-50 | Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr | |-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------| | ERM | 46.1 | 58.4 | 41.1 | 65.9 | 73.7 | 53.1 | 60.1 | 63.3 | 52.2 | 36.7 | 71.8 | 64.8 | 42.6 | 75.2 | | DAN | 56.3 | 61.4 | 45.6 | 67.7 | 73.9 | 57.7 | 63.8 | 66.0 | 54.9 | 40.0 | 74.5 | 66.2 | 49.1 | 77.9 | | DANN | 57.6 | 65.2 | 53.8 | 62.6 | 74.0 | 55.8 | 67.3 | 67.3 | 55.8 | 55.1 | 77.9 | 71.1 | 60.7 | 81.1 | | ADDA | / | 65.6 | 52.6 | 62.9 | 74.0 | 59.7 | 68.0 | 68.8 | 61.4 | 52.5 | 77.6 | 71.1 | 58.6 | 80.2 | | JAN | 58.3 | 65.9 | 50.8 | 71.9 | 76.5 | 60.6 | 68.3 | 68.7 | 60.5 | 49.6 | 76.9 | 71.0 | 55.9 | 80.5 | | CDAN | 65.8 | 68.8 | 55.2 | 72.4 | 77.6 | 62.0 | 69.7 | 70.9 | 62.4 | 54.3 | 80.5 | 75.5 | 61.0 | 83.8 | | MCD | / | 67.8 | 51.7 | 72.2 | 78.2 | 63.7 | 69.5 | 70.8 | 
61.5 | 52.8 | 78.0 | 74.5 | 58.4 | 81.8 | | BSP | 64.9 | 67.6 | 54.7 | 67.7 | 76.2 | 61.0 | 69.4 | 70.9 | 60.9 | 55.2 | 80.2 | 73.4 | 60.3 | 81.2 | | AFN | 67.3 | 68.2 | 53.2 | 72.7 | 76.8 | 65.0 | 71.3 | 72.3 | 65.0 | 51.4 | 77.9 | 72.3 | 57.8 | 82.4 | | MDD | 68.1 | 69.7 | 56.2 | 75.4 | 79.6 | 63.5 | 72.1 | 73.8 | 62.5 | 54.8 | 79.9 | 73.5 | 60.9 | 84.5 | | MCC | / | 72.4 | 58.4 | 79.6 | 83.0 | 67.5 | 77.0 | 78.5 | 66.6 | 54.8 | 81.8 | 74.4 | 61.4 | 85.6 | | FixMatch | / | 70.8 | 56.4 | 76.4 | 79.9 | 65.3 | 73.8 | 71.2 | 67.2 | 56.4 | 80.6 | 74.9 | 63.5 | 84.3 | ### Office-Home accuracy on vit_base_patch16_224 (batch size 24) | Methods | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr | Avg | |-------------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|------| | Source Only | 52.4 | 82.1 | 86.9 | 76.8 | 84.1 | 86 | 75.1 | 51.2 | 88.1 | 78.3 | 51.5 | 87.8 | 75.0 | | DANN | 60.1 | 80.8 | 87.9 | 78.1 | 82.6 | 85.9 | 78.8 | 63.2 | 90.2 | 82.3 | 64 | 89.3 | 78.6 | | DAN | 56.3 | 83.6 | 87.5 | 77.7 | 84.7 | 86.7 | 75.9 | 54.5 | 88.5 | 80.2 | 56.2 | 88.2 | 76.7 | | JAN | 60.1 | 86.9 | 88.6 | 79.2 | 85.4 | 86.7 | 80.4 | 59.4 | 89.6 | 82 | 60.7 | 89.9 | 79.1 | | CDAN | 61.6 | 87.8 | 89.6 | 81.4 | 88.1 | 88.5 | 82.4 | 62.5 | 90.8 | 84.2 | 63.5 | 90.8 | 80.9 | | MCD | 52.3 | 75.3 | 85.3 | 75.4 | 75.4 | 78.3 | 68.8 | 49.7 | 86 | 80.6 | 60 | 89 | 73.0 | | AFN | 58.3 | 87.2 | 88.2 | 81.7 | 87 | 88.2 | 81 | 58.4 | 89.2 | 81.5 | 59.2 | 89.2 | 79.1 | | MDD | 64 | 89.3 | 90.4 | 82.2 | 87.7 | 89.2 | 82.8 | 64.9 | 91.7 | 83.7 | 65.4 | 92 | 81.9 | ### VisDA-2017 accuracy ResNet-101 | Methods | Origin | Mean | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck | Avg | |-------------|--------|------|-------|-------|------|------|-------|-------|-------|--------|-------|--------|-------|-------|------| | 
ERM | 52.4 | 51.7 | 63.6 | 35.3 | 50.6 | 78.2 | 74.6 | 18.7 | 82.1 | 16.0 | 84.2 | 35.5 | 77.4 | 4.7 | 56.9 | | DANN | 57.4 | 79.5 | 93.5 | 74.3 | 83.4 | 50.7 | 87.2 | 90.2 | 89.9 | 76.1 | 88.1 | 91.4 | 89.7 | 39.8 | 74.9 | | ADDA | / | 77.5 | 95.6 | 70.8 | 84.4 | 54.0 | 87.8 | 75.8 | 88.4 | 69.3 | 84.1 | 86.2 | 85.0 | 48.0 | 74.3 | | BSP | 75.9 | 80.5 | 95.7 | 75.6 | 82.8 | 54.5 | 89.2 | 96.5 | 91.3 | 72.2 | 88.9 | 88.7 | 88.0 | 43.4 | 76.2 | | DAN | 61.1 | 66.4 | 89.2 | 37.2 | 77.7 | 61.8 | 81.7 | 64.3 | 90.6 | 61.4 | 79.9 | 37.7 | 88.1 | 27.4 | 67.2 | | JAN | / | 73.4 | 96.3 | 66.0 | 82.0 | 44.1 | 86.4 | 70.3 | 87.9 | 74.6 | 83.0 | 64.6 | 84.5 | 41.3 | 70.3 | | CDAN | / | 80.1 | 94.0 | 69.2 | 78.9 | 57.0 | 89.8 | 94.9 | 91.9 | 80.3 | 86.8 | 84.9 | 85.0 | 48.5 | 76.5 | | MCD | 71.9 | 77.7 | 87.8 | 75.7 | 84.2 | 78.1 | 91.6 | 95.3 | 88.1 | 78.3 | 83.4 | 64.5 | 84.8 | 20.9 | 76.7 | | AFN | 76.1 | 75.0 | 95.6 | 56.2 | 81.3 | 69.8 | 93.0 | 81.0 | 93.4 | 74.1 | 91.7 | 55.0 | 90.6 | 18.1 | 74.4 | | MDD | / | 82.0 | 88.3 | 62.8 | 85.2 | 69.9 | 91.9 | 95.1 | 94.4 | 81.2 | 93.8 | 89.8 | 84.1 | 47.9 | 79.8 | | MCC | 78.8 | 83.6 | 95.3 | 85.8 | 77.1 | 68.0 | 93.9 | 92.9 | 84.5 | 79.5 | 93.6 | 93.7 | 85.3 | 53.8 | 80.4 | | FixMatch | / | 79.5 | 96.5 | 76.6 | 72.6 | 84.6 | 96.3 | 92.6 | 90.5 | 81.8 | 91.9 | 74.6 | 87.3 | 8.6 | 78.4 | ### DomainNet accuracy on ResNet-101 | Methods | c->p | c->r | c->s | p->c | p->r | p->s | r->c | r->p | r->s | s->c | s->p | s->r | Avg | |-------------|------|------|------|------|------|------|------|------|------|------|------|------|------| | ERM | 32.7 | 50.6 | 39.4 | 41.1 | 56.8 | 35.0 | 48.6 | 48.8 | 36.1 | 49.0 | 34.8 | 46.1 | 43.3 | | DAN | 38.8 | 55.2 | 43.9 | 45.9 | 59.0 | 40.8 | 50.8 | 49.8 | 38.9 | 56.1 | 45.9 | 55.5 | 48.4 | | DANN | 37.9 | 54.3 | 44.4 | 41.7 | 55.6 | 36.8 | 50.7 | 50.8 | 40.1 | 55.0 | 45.0 | 54.5 | 47.2 | | JAN | 40.5 | 56.7 | 45.1 | 47.2 | 59.9 | 43.0 | 54.2 | 52.6 | 41.9 | 56.6 | 46.2 | 55.5 | 50.0 | | CDAN | 
40.4 | 56.8 | 46.1 | 45.1 | 58.4 | 40.5 | 55.6 | 53.6 | 43.0 | 57.2 | 46.4 | 55.7 | 49.9 | | MCD | 37.5 | 52.9 | 44.0 | 44.6 | 54.5 | 41.6 | 52.0 | 51.5 | 39.7 | 55.5 | 44.6 | 52.0 | 47.5 | | MDD | 42.9 | 59.5 | 47.5 | 48.6 | 59.4 | 42.6 | 58.3 | 53.7 | 46.2 | 58.7 | 46.5 | 57.7 | 51.8 | | MCC | 37.7 | 55.7 | 42.6 | 45.4 | 59.8 | 39.9 | 54.4 | 53.1 | 37.0 | 58.1 | 46.3 | 56.2 | 48.9 | ### DomainNet accuracy on ResNet-101 (Multi-Source) | Methods | Origin | Avg | :c | :i | :p | :q | :r | :s | |-------------|--------|------|------|------|------|------|------|------| | ERM | 32.9 | 47.0 | 64.9 | 25.2 | 54.4 | 16.9 | 68.2 | 52.3 | | MDD | / | 48.8 | 68.7 | 29.7 | 58.2 | 9.7 | 69.4 | 56.9 | | Oracle | 63.0 | 69.1 | 78.2 | 40.7 | 71.6 | 69.7 | 83.8 | 70.6 | ### Performance on ImageNet-scale dataset | | ResNet50, ImageNet->ImageNetR | ig_resnext101_32x8d, ImageNet->ImageSketch | |------|-------------------------------|------------------------------------------| | ERM | 35.6 | 54.9 | | DAN | 39.8 | 55.7 | | DANN | 52.7 | 56.5 | | JAN | 41.7 | 55.7 | | CDAN | 53.9 | 58.2 | | MCD | 46.7 | 55.0 | | AFN | 43.0 | 55.1 | | MDD | 56.2 | 62.4 | ## Visualization After training `DANN`, run the following command ``` CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W --phase analysis ``` It may take a while, then in directory ``logs/dann/Office31_A2W/visualize``, you can find ``TSNE.png``. Following are the t-SNE of representations from ResNet50 trained on source domain and those from DANN. ## TODO 1. Support self-training methods 2. Support translation methods 3. Add results on ViT 4. Add results on ImageNet ## Citation If you use these methods in your research, please consider citing. 
``` @inproceedings{DANN, author = {Ganin, Yaroslav and Lempitsky, Victor}, Booktitle = {ICML}, Title = {Unsupervised domain adaptation by backpropagation}, Year = {2015} } @inproceedings{DAN, author = {Mingsheng Long and Yue Cao and Jianmin Wang and Michael I. Jordan}, title = {Learning Transferable Features with Deep Adaptation Networks}, booktitle = {ICML}, year = {2015}, } @inproceedings{JAN, title={Deep transfer learning with joint adaptation networks}, author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I}, booktitle={ICML}, year={2017}, } @inproceedings{ADDA, title={Adversarial discriminative domain adaptation}, author={Tzeng, Eric and Hoffman, Judy and Saenko, Kate and Darrell, Trevor}, booktitle={CVPR}, year={2017} } @inproceedings{CDAN, author = {Mingsheng Long and Zhangjie Cao and Jianmin Wang and Michael I. Jordan}, title = {Conditional Adversarial Domain Adaptation}, booktitle = {NeurIPS}, year = {2018} } @inproceedings{MCD, title={Maximum classifier discrepancy for unsupervised domain adaptation}, author={Saito, Kuniaki and Watanabe, Kohei and Ushiku, Yoshitaka and Harada, Tatsuya}, booktitle={CVPR}, year={2018} } @InProceedings{AFN, author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang}, title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation}, booktitle = {ICCV}, year = {2019} } @inproceedings{MDD, title={Bridging theory and algorithm for domain adaptation}, author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael}, booktitle={ICML}, year={2019}, } @inproceedings{BSP, title={Transferability vs. 
discriminability: Batch spectral penalization for adversarial domain adaptation}, author={Chen, Xinyang and Wang, Sinan and Long, Mingsheng and Wang, Jianmin}, booktitle={ICML}, year={2019}, } @inproceedings{MCC, author = {Ying Jin and Ximei Wang and Mingsheng Long and Jianmin Wang}, title = {Less Confusion More Transferable: Minimum Class Confusion for Versatile Domain Adaptation}, year={2020}, booktitle={ECCV}, } @inproceedings{FixMatch, title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence}, author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang}, booktitle={NIPS}, year={2020} } ``` ================================================ FILE: examples/domain_adaptation/image_classification/adda.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com Note: Our implementation is different from ADDA paper in several respects. We do not use separate networks for source and target domain, nor fix classifier head. Besides, we do not adopt asymmetric objective loss function of the feature extractor. 
""" import random import time import warnings import copy import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import utils from tllib.alignment.adda import ImageClassifier from tllib.alignment.dann import DomainAdversarialLoss from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.modules.grl import WarmStartGradientReverseLayer from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def set_requires_grad(net, requires_grad=False): """ Set requies_grad=Fasle for all the networks to avoid unnecessary computations """ for param in net.parameters(): param.requires_grad = requires_grad def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None source_classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) if args.phase == 'train' and args.pretrain is None: # first pretrain the classifier wish source data print("Pretraining the model on source domain.") args.pretrain = logger.get_checkpoint_path('pretrain') pretrain_model = ImageClassifier(backbone, 
num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) pretrain_lr_scheduler = LambdaLR(pretrain_optimizer, lambda x: args.pretrain_lr * (1. + args.lr_gamma * float(x)) ** ( -args.lr_decay)) # start pretraining for epoch in range(args.pretrain_epochs): print("lr:", pretrain_lr_scheduler.get_lr()) # pretrain for one epoch utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args, device) # validate to show pretrain process utils.validate(val_loader, pretrain_model, args, device) torch.save(pretrain_model.state_dict(), args.pretrain) print("Pretraining process is done.") checkpoint = torch.load(args.pretrain, map_location='cpu') source_classifier.load_state_dict(checkpoint) target_classifier = copy.deepcopy(source_classifier) # freeze source classifier set_requires_grad(source_classifier, False) source_classifier.freeze_bn() domain_discri = DomainDiscriminator(in_feature=source_classifier.features_dim, hidden_size=1024).to(device) # define loss function grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=2., max_iters=1000, auto_step=True) domain_adv = DomainAdversarialLoss(domain_discri, grl=grl).to(device) # define optimizer and lr scheduler # note that we only optimize target feature extractor optimizer = SGD(target_classifier.get_parameters(optimize_head=False) + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') target_classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(target_classifier.backbone, target_classifier.pool_layer, target_classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, target_classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          source_model: ImageClassifier, target_model: ImageClassifier, domain_adv: DomainAdversarialLoss,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of ADDA adversarial adaptation.

    Only the target feature extractor and the domain discriminator are
    optimized; the source model is frozen in ``main`` (``set_requires_grad``
    with ``False`` plus ``freeze_bn``), so its forward pass is executed under
    ``torch.no_grad()`` — this skips building its autograd graph (saves memory
    and compute) while producing exactly the same loss and gradients.

    Args:
        train_source_iter: endless iterator over source-domain batches
        train_target_iter: endless iterator over target-domain batches
        source_model: frozen classifier providing reference source features
        target_model: classifier whose feature extractor is being adapted
        domain_adv: adversarial loss wrapping the domain discriminator
        optimizer: optimizer over target features + discriminator parameters
        lr_scheduler: stepped once per iteration
        epoch: current epoch index (for logging only)
        args: command-line arguments (iters_per_epoch, print_freq, ...)
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses_transfer = AverageMeter('Transfer Loss', ':6.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_transfer, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    target_model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # labels are unused here: ADDA's adaptation step is purely adversarial
        x_s, = next(train_source_iter)[:1]
        x_t, = next(train_target_iter)[:1]

        x_s = x_s.to(device)
        x_t = x_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # source model is frozen — no autograd graph needed for its features
        with torch.no_grad():
            _, f_s = source_model(x_s)
        _, f_t = target_model(x_t)

        loss_transfer = domain_adv(f_s, f_t)

        # Compute gradient and do SGD step
        optimizer.zero_grad()
        loss_transfer.backward()
        optimizer.step()
        lr_scheduler.step()

        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        # accuracy of the discriminator on the current source/target batch
        domain_acc = domain_adv.domain_discriminator_accuracy
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    # Command-line interface for the ADDA example script.
    arg_parser = argparse.ArgumentParser(description='ADDA for Unsupervised Domain Adaptation')

    # ---------------- dataset parameters ----------------
    arg_parser.add_argument('root', metavar='DIR', help='root path of dataset')
    arg_parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                            help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)')
    arg_parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    arg_parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    arg_parser.add_argument('--train-resizing', type=str, default='default')
    arg_parser.add_argument('--val-resizing', type=str, default='default')
    arg_parser.add_argument('--resize-size', type=int, default=224,
                            help='the image size after resizing')
    arg_parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                            help='Random resize scale (default: 0.08 1.0)')
    arg_parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                            help='Random resize aspect ratio (default: 0.75 1.33)')
    arg_parser.add_argument('--no-hflip', action='store_true',
                            help='no random horizontal flipping during training')
    arg_parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406),
                            help='normalization mean')
    arg_parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225),
                            help='normalization std')

    # ---------------- model parameters ----------------
    arg_parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                            help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                                 ' (default: resnet18)')
    arg_parser.add_argument('--pretrain', type=str, default=None,
                            help='pretrain checkpoint for classification model')
    arg_parser.add_argument('--bottleneck-dim', default=256, type=int,
                            help='Dimension of bottleneck')
    arg_parser.add_argument('--no-pool', action='store_true',
                            help='no pool layer after the feature extractor.')
    arg_parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')

    # ---------------- training parameters ----------------
    arg_parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                            help='mini-batch size (default: 32)')
    arg_parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
                            help='initial learning rate of the classifier', dest='lr')
    arg_parser.add_argument('--pretrain-lr', default=0.001, type=float,
                            help='initial pretrain learning rate')
    arg_parser.add_argument('--lr-gamma', default=0.0003, type=float,
                            help='parameter for lr scheduler')
    arg_parser.add_argument('--lr-decay', default=0.75, type=float,
                            help='parameter for lr scheduler')
    arg_parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    arg_parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, metavar='W',
                            help='weight decay (default: 1e-3)', dest='weight_decay')
    arg_parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 2)')
    arg_parser.add_argument('--epochs', default=20, type=int, metavar='N',
                            help='number of total epochs to run')
    arg_parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N',
                            help='number of total epochs (pretrain) to run')
    arg_parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                            help='Number of iterations per epoch')
    arg_parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                            help='print frequency (default: 100)')
    arg_parser.add_argument('--seed', default=None, type=int,
                            help='seed for initializing training. ')
    arg_parser.add_argument('--per-class-eval', action='store_true',
                            help='whether output per-class accuracy during evaluation')
    arg_parser.add_argument("--log", type=str, default='adda',
                            help="Where to save logs, checkpoints and debugging images.")
    arg_parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                            help="When phase is 'test', only test the model."
                                 "When phase is 'analysis', only analysis the model.")

    args = arg_parser.parse_args()
    main(args)
-a resnet50 --epochs 20 --seed 1 --log logs/adda/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=0 python adda.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a 
resnet101 \ --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/adda/VisDA2017 # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2c CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t p -a resnet101 
--bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2p CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python adda.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/adda/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python adda.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/dann_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool 
--log logs/dann_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_Rw2Pr # ResNet50, Office-Home, Multi Source CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Ar CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Cl CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Pr CUDA_VISIBLE_DEVICES=0 python adda.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/adda/OfficeHome_:2Rw # ResNet101, DomainNet, Multi Source CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2c CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2i 
CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2p CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2q CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2r CUDA_VISIBLE_DEVICES=0 python adda.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/adda/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/adda/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python adda.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' 
def main(args: argparse.Namespace):
    """Entry point of the AFN example: build data/model, then run the phase
    selected by ``args.phase`` ('train', 'test' or 'analysis')."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # Optional deterministic seeding (slows cudnn but makes runs repeatable).
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps every training batch full-sized for both domains
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # endless iterators so the training loop can draw a fixed number of
    # iterations per epoch regardless of dataset sizes
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, args.num_blocks,
                                 bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    # AFN loss: pushes feature norms up by `delta` each step
    adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device)

    # define optimizer
    # the learning rate is fixed according to origin paper
    optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay)

    # resume from the best checkpoint (test/analysis phases only)
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one AFN epoch.

    The objective is source cross-entropy plus the adaptive feature-norm loss
    on both domains (weighted by ``args.trade_off_norm``), optionally adding
    target-entropy minimization when ``args.trade_off_entropy`` is set.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    norm_losses = AverageMeter('Norm Loss', ':3.2f')
    src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f')
    tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # training mode (enables dropout / BN statistic updates)
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images_s, labels_s = next(train_source_iter)[:2]
        images_t, = next(train_target_iter)[:1]

        images_s = images_s.to(device)
        images_t = images_t.to(device)
        labels_s = labels_s.to(device)

        # time spent waiting on the data pipeline
        data_time.update(time.time() - tic)

        # forward both domains through the shared model
        y_s, f_s = model(images_s)
        y_t, f_t = model(images_t)

        # supervised loss on source + feature-norm loss on both domains
        cls_loss = F.cross_entropy(y_s, labels_s)
        norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t)
        loss = cls_loss + norm_loss * args.trade_off_norm

        # optional entropy minimization on target predictions
        if args.trade_off_entropy:
            y_t = F.softmax(y_t, dim=1)
            entropy_loss = entropy(y_t, reduction='mean')
            loss += entropy_loss * args.trade_off_entropy

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update statistics
        batch_size = images_s.size(0)
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_losses.update(cls_loss.item(), batch_size)
        norm_losses.update(norm_loss.item(), batch_size)
        src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), batch_size)
        tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), batch_size)
        cls_accs.update(cls_acc.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
# Script entry point: parse CLI arguments and launch training / evaluation.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='AFN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='ran.crop')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406),
                        help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225),
                        help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('-n', '--num-blocks', default=1, type=int,
                        help='Number of basic blocks for classifier')
    parser.add_argument('--bottleneck-dim', default=1000, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--dropout-p', default=0.5, type=float,
                        help='Dropout probability')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default: 5e-4)', dest='weight_decay')
    parser.add_argument('--trade-off-norm', default=0.05, type=float,
                        help='the trade-off hyper-parameter for norm loss')
    parser.add_argument('--trade-off-entropy', default=None, type=float,
                        help='the trade-off hyper-parameter for entropy loss')
    parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='afn',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
logs/afn/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/afn/OfficeHome_Rw2Pr

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 -r 0.3 -b 36 \
  --epochs 10 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017

# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off-norm 0.01 --lr 0.002 --log logs/afn/DomainNet_s2r

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python afn.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/afn/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python afn.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/afn_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/afn_vit/OfficeHome_Rw2Pr

# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_:2Rw

# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python afn.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/afn/DomainNet_:2s

# Digits
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.3 --lr 0.01 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool -r 0.1 --lr 0.03 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python afn.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.'
\ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool -r 0.1 --lr 0.1 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/afn/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/bsp.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.alignment.dann import DomainAdversarialLoss from tllib.alignment.bsp import BatchSpectralPenalizationLoss, ImageClassifier from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, 
weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function domain_adv = DomainAdversarialLoss(domain_discri).to(device) bsp_penalty = BatchSpectralPenalizationLoss().to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return if args.pretrain is None: # first pretrain the classifier wish source data print("Pretraining the model on source domain.") args.pretrain = logger.get_checkpoint_path('pretrain') pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) pretrain_optimizer = SGD(pretrain_model.get_parameters(), args.pretrain_lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) pretrain_lr_scheduler = LambdaLR(pretrain_optimizer, lambda x: args.pretrain_lr * (1. 
+ args.lr_gamma * float(x)) ** ( -args.lr_decay)) # start pretraining for epoch in range(args.pretrain_epochs): print("lr:", pretrain_lr_scheduler.get_lr()) # pretrain for one epoch utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args, device) # validate to show pretrain process utils.validate(val_loader, pretrain_model, args, device) torch.save(pretrain_model.state_dict(), args.pretrain) print("Pretraining process is done.") checkpoint = torch.load(args.pretrain, map_location='cpu') classifier.load_state_dict(checkpoint) # start training best_acc1 = 0. for epoch in range(args.epochs): print("lr:", lr_scheduler.get_last_lr()[0]) # train for one epoch train(train_source_iter, train_target_iter, classifier, domain_adv, bsp_penalty, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, domain_adv: DomainAdversarialLoss, bsp_penalty: BatchSpectralPenalizationLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':5.2f') data_time = AverageMeter('Data', ':5.2f') losses = AverageMeter('Loss', ':6.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') domain_accs = AverageMeter('Domain Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, 
[batch_time, data_time, losses, cls_accs, domain_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() domain_adv.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] x_t, = next(train_target_iter)[:1] x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output x = torch.cat((x_s, x_t), dim=0) y, f = model(x) y_s, y_t = y.chunk(2, dim=0) f_s, f_t = f.chunk(2, dim=0) cls_loss = F.cross_entropy(y_s, labels_s) transfer_loss = domain_adv(f_s, f_t) bsp_loss = bsp_penalty(f_s, f_t) domain_acc = domain_adv.domain_discriminator_accuracy loss = cls_loss + transfer_loss * args.trade_off + bsp_loss * args.trade_off_bsp cls_acc = accuracy(y_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) domain_accs.update(domain_acc.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='BSP for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale 
(default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--pretrain', type=str, default=None, help='pretrain checkpoint for classification model') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') parser.add_argument('--trade-off-bsp', default=2e-4, type=float, help='the trade-off hyper-parameter for bsp loss') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--pretrain-lr', default=0.001, type=float, help='initial pretrain learning rate') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, 
help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--pretrain-epochs', default=3, type=int, metavar='N', help='number of total epochs(pretrain) to run (default: 3)') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='bsp', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_classification/bsp.sh ================================================ #!/usr/bin/env bash # ResNet50, Office31, Single Source CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2W CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2W CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2D CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_A2D CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_D2A CUDA_VISIBLE_DEVICES=0 python bsp.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 1 --log logs/bsp/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log 
logs/bsp/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_Rw2Pr

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/bsp/VisDA2017

# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/DomainNet_s2r

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python bsp.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/bsp/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
# NOTE: log dirs below were renamed from logs/dann_* (copy-paste from dann.sh) to logs/bsp_*
# for consistency with the rest of this script and with afn.sh's logs/afn_* convention
CUDA_VISIBLE_DEVICES=0 python bsp.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/bsp_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/bsp_vit/OfficeHome_Rw2Pr

# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python bsp.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/bsp/OfficeHome_:2Rw

# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python bsp.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/bsp/DomainNet_:2s

# Digits
CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.'
\ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python bsp.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 --trade-off-bsp 0.0001 -b 128 -i 2500 --scratch --seed 0 --log logs/bsp/SVHN2MNIST --pretrain-epochs 1 ================================================ FILE: examples/domain_adaptation/image_classification/cc_loss.py ================================================ """ @author: Ying Jin @contact: sherryying003@gmail.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier from tllib.self_training.cc_loss import CCConsistency from tllib.vision.transforms import MultipleApply from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: 
random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std, auto_augment=args.auto_augment) train_target_transform = MultipleApply([weak_augment, strong_augment]) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_source_transform: ", train_source_transform) print("train_target_transform: ", train_target_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform, train_target_transform=train_target_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, 
shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) 
return # start training best_acc1 = 0. for epoch in range(args.epochs): print("lr:", lr_scheduler.get_last_lr()[0]) # train for one epoch train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':3.1f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') trans_losses = AverageMeter('Trans Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, trans_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # define loss function mcc = MinimumClassConfusionLoss(temperature=args.temperature) consistency = CCConsistency(temperature=args.temperature, thr=args.thr) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] (x_t, x_t_strong), labels_t = next(train_target_iter)[:2] x_s = x_s.to(device) x_t = x_t.to(device) x_t_strong = x_t_strong.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output x = torch.cat((x_s, x_t, x_t_strong), dim=0) y, f = model(x) y_s, 
y_t, y_t_strong = y.chunk(3, dim=0) cls_loss = F.cross_entropy(y_s, labels_s) mcc_loss = mcc(y_t) consistency_loss, selec_ratio = consistency(y_t, y_t_strong) loss = cls_loss + mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency transfer_loss = mcc_loss * args.trade_off + consistency_loss * args.trade_off_consistency cls_acc = accuracy(y_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='CC Loss for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str, help='AutoAugment policy (default: rand-m10-n2-mstd2)') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--temperature', default=2.5, type=float, help='parameter temperature scaling') parser.add_argument('--thr', default=0.95, type=float, help='thr parameter for consistency loss') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for original mcc loss') parser.add_argument('--trade_off_consistency', default=1., type=float, help='the trade-off hyper-parameter for consistency loss') # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.005, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', 
default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='mcc', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_classification/cc_loss.sh ================================================ #!/usr/bin/env bash # ResNet50, Office31, Single Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2W CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2W CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2D CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_A2D CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_D2A CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/cc_loss/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 
--epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/cc_loss/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=5 python cc_loss.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \ --epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/cc_loss/VisDA2017 # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log 
logs/cc_loss/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2c CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t c -a resnet101 
--epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2p CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/cc_loss/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/cc_loss/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/cc_loss_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 
python cc_loss.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/cc_loss_vit/OfficeHome_Rw2Pr # ResNet50, Office-Home, Multi Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Ar CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 
--log logs/cc_loss/OfficeHome_:2Cl CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Pr CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/cc_loss/OfficeHome_:2Rw # ResNet101, DomainNet, Multi Source CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2c CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2i CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2p CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2q CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2r CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cc_loss/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' 
\ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python cc_loss.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cc_loss/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/cdan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) classifier_feature_dim = classifier.features_dim if args.randomized: domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device) else: domain_discri = DomainDiscriminator(classifier_feature_dim * num_classes, hidden_size=1024).to(device) 
all_parameters = classifier.get_parameters() + domain_discri.get_parameters() # define optimizer and lr scheduler optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function domain_adv = ConditionalDomainAdversarialLoss( domain_discri, entropy_conditioning=args.entropy, num_classes=num_classes, features_dim=classifier_feature_dim, randomized=args.randomized, randomized_dim=args.randomized_dim ).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
for epoch in range(args.epochs): print("lr:", lr_scheduler.get_last_lr()[0]) # train for one epoch train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, domain_adv: ConditionalDomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':3.1f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') trans_losses = AverageMeter('Trans Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') domain_accs = AverageMeter('Domain Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, trans_losses, cls_accs, domain_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() domain_adv.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] x_t, = next(train_target_iter)[:1] x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output x = torch.cat((x_s, x_t), dim=0) y, f = model(x) y_s, y_t = y.chunk(2, dim=0) f_s, f_t = f.chunk(2, dim=0) cls_loss = F.cross_entropy(y_s, labels_s) transfer_loss = 
domain_adv(y_s, f_s, y_t, f_t) domain_acc = domain_adv.domain_discriminator_accuracy loss = cls_loss + transfer_loss * args.trade_off cls_acc = accuracy(y_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc, x_s.size(0)) domain_accs.update(domain_acc, x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='CDAN for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('-r', '--randomized', action='store_true', help='using randomized multi-linear-map (default: False)') parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, help='randomized dimension when using randomized multi-linear-map (default: 1024)') parser.add_argument('--entropy', default=False, action='store_true', help='use entropy conditioning') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', 
'--weight-decay', default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='cdan', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/image_classification/cdan.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 2 --log logs/cdan/Office31_W2A

# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_Rw2Pr

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/cdan/VisDA2017

# ResNet101, DomainNet, Single Source
# Use randomized multi-linear-map to decrease GPU memory usage
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/DomainNet_s2r

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python cdan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/cdan/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python cdan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/cdan_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/cdan_vit/OfficeHome_Rw2Pr

# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python cdan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/cdan/OfficeHome_:2Rw

# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python cdan.py data/domainnet -d DomainNet -s c i
p q r -t s -a resnet101 --bottleneck-dim 1024 -r -rd 51200 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/cdan/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python cdan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/cdan/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/dan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier from tllib.modules.kernels import GaussianKernel from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def 
main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio, random_horizontal_flip=not args.no_hflip, random_color_jitter=False, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size, norm_mean=args.norm_mean, norm_std=args.norm_std) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, 
bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy( kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)], linear=not args.non_linear ) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
for epoch in range(args.epochs): # train for one epoch train(train_source_iter, train_target_iter, classifier, mkmmd_loss, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, mkmmd_loss: MultipleKernelMaximumMeanDiscrepancy, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') trans_losses = AverageMeter('Trans Loss', ':5.4f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, trans_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() mkmmd_loss.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] x_t, = next(train_target_iter)[:1] x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s, f_s = model(x_s) y_t, f_t = model(x_t) cls_loss = F.cross_entropy(y_s, labels_s) transfer_loss = mkmmd_loss(f_s, f_t) loss = cls_loss + transfer_loss * args.trade_off cls_acc = accuracy(y_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), 
x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='DAN for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--non-linear', default=False, action='store_true', help='whether not use the linear version') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, 
type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/image_classification/dan.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2A
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python dan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/dan/Office31_W2D

# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 -i 500 --seed 0 --log logs/dan/OfficeHome_Rw2Pr

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dan/VisDA2017

# ResNet101, DomainNet, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_c2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_p2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_r2s
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/DomainNet_s2r

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python dan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dan/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python dan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --log logs/dan_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/dan_vit/OfficeHome_Rw2Pr

# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python dan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dan/OfficeHome_:2Rw

# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python dan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dan/DomainNet_:2s

# Digits
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
  --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dan/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python dan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.'
def main(args: argparse.Namespace):
    """Entry point for DANN unsupervised domain adaptation.

    Depending on ``args.phase`` this either trains the model ('train'),
    evaluates the best checkpoint on the target test set ('test'), or
    visualizes features with t-SNE and estimates the A-distance ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding enables deterministic CUDNN, which can slow training down.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True on the training loaders keeps source/target batch sizes equal
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so each epoch runs exactly args.iters_per_epoch steps
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler (classifier and discriminator are optimized jointly)
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay, nesterov=True)
    # inverse-decay schedule: lr * (1 + gamma * step) ** (-decay)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of DANN training.

    Each step draws one source batch and one target batch, forwards them through
    the model as a single concatenated batch, and minimizes the source
    cross-entropy plus the trade-off-weighted domain adversarial loss.
    Running averages of loss/accuracy/timing are displayed every
    ``args.print_freq`` steps.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    last_tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]

        x_s, x_t, labels_s = x_s.to(device), x_t.to(device), labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - last_tick)

        # forward source and target together so BN sees both domains
        logits, features = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = logits.chunk(2, dim=0)
        f_s, f_t = features.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        n = x_s.size(0)
        losses.update(loss.item(), n)
        cls_accs.update(cls_acc.item(), n)
        domain_accs.update(domain_acc.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - last_tick)
        last_tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
if __name__ == '__main__':
    # Command-line interface; defaults reproduce the ResNet-18 / Office31 setup.
    parser = argparse.ArgumentParser(description='DANN for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, metavar='W',
                        help='weight decay (default: 1e-3)', dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
#!/usr/bin/env bash
# DANN experiments. Each invocation trains one adaptation task; the repeated
# per-task command lists are expressed as loops over source/target pairs, but
# the exact set of commands executed is unchanged.

# ResNet50, Office31, Single Source
for pair in A,W D,W W,D A,D D,A W,A; do
    src=${pair%,*}; tgt=${pair#*,}
    CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s ${src} -t ${tgt} -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_${src}2${tgt}
done

# ResNet50, Office-Home, Single Source
for pair in Ar,Cl Ar,Pr Ar,Rw Cl,Ar Cl,Pr Cl,Rw Pr,Ar Pr,Cl Pr,Rw Rw,Ar Rw,Cl Rw,Pr; do
    src=${pair%,*}; tgt=${pair#*,}
    CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s ${src} -t ${tgt} -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_${src}2${tgt}
done

# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
    --epochs 30 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/dann/VisDA2017

# ResNet101, DomainNet, Single Source
for pair in c,p c,r c,s p,c p,r p,s r,c r,p r,s s,c s,p s,r; do
    src=${pair%,*}; tgt=${pair#*,}
    CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s ${src} -t ${tgt} -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/DomainNet_${src}2${tgt}
done

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/dann/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python dann.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --bottleneck-dim 1024 --log logs/dann_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
for pair in Ar,Cl Ar,Pr Ar,Rw Cl,Ar Cl,Pr Cl,Rw Pr,Ar Pr,Cl Pr,Rw Rw,Ar Rw,Cl Rw,Pr; do
    src=${pair%,*}; tgt=${pair#*,}
    CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s ${src} -t ${tgt} -a vit_base_patch16_224 --epochs 30 --seed 0 -b 24 --no-pool --log logs/dann_vit/OfficeHome_${src}2${tgt}
done

# ResNet50, Office-Home, Multi Source (sources = all domains except the target)
for tgt in Ar Cl Pr Rw; do
    srcs=""
    for d in Ar Cl Pr Rw; do [ "${d}" != "${tgt}" ] && srcs="${srcs} ${d}"; done
    CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s ${srcs} -t ${tgt} -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_:2${tgt}
done

# ResNet101, DomainNet, Multi Source (sources = all domains except the target)
for tgt in c i p q r s; do
    srcs=""
    for d in c i p q r s; do [ "${d}" != "${tgt}" ] && srcs="${srcs} ${d}"; done
    CUDA_VISIBLE_DEVICES=0 python dann.py data/domainnet -d DomainNet -s ${srcs} -t ${tgt} -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/dann/DomainNet_:2${tgt}
done

# Digits
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python dann.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/dann/SVHN2MNIST
def main(args: argparse.Namespace):
    """Train a source-only (ERM) baseline for unsupervised domain adaptation.

    Depending on ``args.phase`` this either trains the classifier on the
    source domain ('train'), evaluates the best checkpoint on the target test
    set ('test'), or analyses the learned features with t-SNE and A-distance
    ('analysis').
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seeding enables deterministic CUDNN, which can slow training down.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # target loader is only consumed by the 'analysis' phase below
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    # inverse-decay schedule: lr * (1 + gamma * step) ** (-decay)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # FIX: the original called lr_scheduler.get_lr(), which is only meant
        # to be used inside LRScheduler.step() and emits a deprecation warning
        # when called here; get_last_lr() is the documented accessor and
        # matches the convention used by dann.py in this directory.
        print("lr:", lr_scheduler.get_last_lr()[0])

        # train for one epoch
        utils.empirical_risk_minimization(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args, device)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
if __name__ == '__main__':
    # Command-line interface; defaults reproduce the ResNet-18 / Office31 setup.
    parser = argparse.ArgumentParser(description='Source Only for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    # NOTE: no explicit dest here, so this option is read back as args.wd
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W',
                        help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
--epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \ --epochs 20 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/erm/VisDA2017 # ResNet101, DomainNet, Oracle CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_c CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log 
logs/oracle/DomainNet_i CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_p CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s q -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_q CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_r CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/oracle/DomainNet_s # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2c CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d 
DomainNet -s r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2p CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 20 -i 2500 --seed 0 --log logs/erm/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python erm.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 20 -i 2500 --seed 0 --log logs/erm_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Pr 
# --- erm.sh (continued): remaining ViT, multi-source and digits runs ---
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 20 -i 1000 -b 24 --seed 0 --log logs/erm_vit/OfficeHome_Rw2Pr
# ResNet50, Office-Home, Multi Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 10 -i 1000 --seed 0 --log logs/erm/OfficeHome_:2Rw
# ResNet101, DomainNet, Multi Source
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2c
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2i
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2p
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2q
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2r
CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 20 -i 2500 --seed 0 --lr 0.01 --log logs/erm/DomainNet_:2s
# Digits
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/MNIST2USPS
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/USPS2MNIST
CUDA_VISIBLE_DEVICES=0 python erm.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \
    --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/erm/SVHN2MNIST

================================================ FILE: examples/domain_adaptation/image_classification/fixmatch.py ================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.classifier import Classifier
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageClassifier(Classifier):
    # Classifier with a Linear->BatchNorm1d->ReLU bottleneck between the
    # backbone and the classification head.
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """Return class logits only (backbone -> pool -> bottleneck -> head);
        unlike the base class, features are not returned."""
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions


def main(args: argparse.Namespace):
    # End-to-end FixMatch driver: builds data pipelines, model and optimizer,
    # then dispatches on args.phase ('train' / 'test' / 'analysis').
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # source: a fixed random-resized-crop policy; target: FixMatch's paired
    # weak augmentation (for pseudo labels) + strong augmentation (auto-augment)
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    # the unlabeled target stream may use a larger batch (-ub) than the source stream
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.unlabeled_batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay, nesterov=True)
    # inverse-decay schedule: lr * (1 + gamma * step) ** (-decay)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    # --- fixmatch.py main() (continued): epoch loop, checkpointing, final test ---
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one FixMatch epoch: supervised cross-entropy on source batches plus a
    confidence-thresholded self-training loss on strongly-augmented target batches,
    with pseudo labels produced from the weakly-augmented view under no_grad.
    Note: gradients from the two losses are accumulated via two separate
    backward() calls before a single optimizer.step() — the order matters.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    cls_losses = AverageMeter('Cls Loss', ':6.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':6.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs,
         pseudo_label_accs, pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        # target batch carries both views: (weak, strong); labels_t is used only
        # for monitoring pseudo-label accuracy, never for the loss
        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        with torch.no_grad():
            # weak-view logits: the pseudo-label source (no gradient)
            y_t = model(x_t)

        # cross entropy loss
        y_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        cls_loss.backward()

        # self-training loss
        y_t_strong = model(x_t_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_t_strong, y_t)
        self_training_loss = args.trade_off * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), x_s.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        self_training_losses.update(self_training_loss.item(), x_s.size(0))
        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # ratio of pseudo labels
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / x_t.size(0)
        pseudo_label_ratios.update(ratio.item() * 100, x_t.size(0))

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
            # masked-out positions are set to -1 so they never match a true label
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_t).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FixMatch for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    # --- fixmatch.py __main__ (continued): remaining CLI options and entry point ---
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)')
    parser.add_argument('-ub', '--unlabeled-batch-size', default=32, type=int,
                        help='mini-batch size of unlabeled data (target domain) (default: 32)')
    parser.add_argument('--threshold', default=0.9, type=float, help='confidence threshold')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, metavar='W',
                        help='weight decay (default: 1e-3)', dest='weight_decay')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)

================================================ FILE: examples/domain_adaptation/image_classification/fixmatch.sh ================================================
#!/usr/bin/env bash
# FixMatch launch scripts: smaller lr / bottleneck and a larger unlabeled batch
# (-ub) on Office31; default lr with a 1024-d bottleneck on Office-Home.

# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.001 --bottleneck-dim 256 -ub 96 --epochs 20 --seed 0 --log logs/fixmatch/Office31_W2A
# ResNet50, Office-Home, Single Source
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --lr 0.003 --bottleneck-dim 1024 --epochs 20 --seed 0 --log logs/fixmatch/OfficeHome_Rw2Pr
# ResNet101, VisDA-2017, Single Source
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --train-resizing cen.crop \
    --lr 0.003 --threshold 0.8 --bottleneck-dim 2048 --epochs 20 -ub 64 --seed 0 --per-class-eval --log logs/fixmatch/VisDA2017

================================================ FILE: examples/domain_adaptation/image_classification/jan.py ================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier, Theta
from tllib.modules.kernels import GaussianKernel
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    # End-to-end JAN driver: data, model, JMMD alignment loss, then dispatch on
    # args.phase ('train' / 'test' / 'analysis').
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define loss function
    if args.adversarial:
        # one adversarial Theta per aligned layer (features and predictions)
        thetas = [Theta(dim).to(device) for dim in (classifier.features_dim, num_classes)]
    else:
        thetas = None
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            # feature layer: bank of Gaussian kernels; prediction layer: one fixed-sigma kernel
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear, thetas=thetas
    ).to(device)

    parameters = classifier.get_parameters()
    if thetas is not None:
        parameters += [{"params": theta.parameters(), 'lr': 0.1} for theta in thetas]

    # define optimizer
    optimizer = SGD(parameters, args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
for epoch in range(args.epochs): # train for one epoch train(train_source_iter, train_target_iter, classifier, jmmd_loss, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, jmmd_loss: JointMultipleKernelMaximumMeanDiscrepancy, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') trans_losses = AverageMeter('Trans Loss', ':5.4f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, trans_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() jmmd_loss.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] x_t, = next(train_target_iter)[:1] x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output x = torch.cat((x_s, x_t), dim=0) y, f = model(x) y_s, y_t = y.chunk(2, dim=0) f_s, f_t = f.chunk(2, dim=0) cls_loss = F.cross_entropy(y_s, labels_s) transfer_loss = jmmd_loss( (f_s, F.softmax(y_s, dim=1)), (f_t, F.softmax(y_t, dim=1)) ) loss = cls_loss + transfer_loss * 
args.trade_off cls_acc = accuracy(y_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='JAN for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true',
                        help='whether train from scratch.')
    parser.add_argument('--linear', default=False, action='store_true',
                        help='whether use the linear version')
    parser.add_argument('--adversarial', default=False, action='store_true',
                        help='whether use adversarial theta')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # NOTE(review): no dest= is given here, so argparse stores this value as
    # args.wd (first long option string), unlike mcc.py in this repo which
    # passes dest='weight_decay'. Confirm main() reads args.wd before
    # "fixing" this inconsistency — adding dest= would break such a reader.
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int,
                        metavar='N', help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='jan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/image_classification/jan.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2A
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python jan.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/jan/Office31_W2D
# ResNet50, 
Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 20 --seed 0 --log logs/jan/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=0 python jan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \ --epochs 20 -i 500 --seed 0 --per-class-eval --train-resizing cen.crop --log 
logs/jan/VisDA2017 # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s p -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2c CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s r -t s -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t c -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t p -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2p 
CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s s -t r -a resnet101 --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python jan.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 -i 2500 -p 500 --seed 0 --log logs/jan/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python jan.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/jan_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Cl -a 
vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/jan_vit/OfficeHome_Rw2Pr # ResNet50, Office-Home, Multi Source CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Ar CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Cl CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Pr CUDA_VISIBLE_DEVICES=0 python jan.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/jan/OfficeHome_:2Rw # ResNet101, DomainNet, Multi Source CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2c CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2i CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 
--log logs/jan/DomainNet_:2p CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2q CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2r CUDA_VISIBLE_DEVICES=0 python jan.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 1024 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/jan/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python jan.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/USPS2MNIST CUDA_VISIBLE_DEVICES=4 python jan.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' 
\ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/jan/SVHN2MNIST


================================================
FILE: examples/domain_adaptation/image_classification/mcc.py
================================================
"""
@author: Ying Jin
@contact: sherryying003@gmail.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.self_training.mcc import MinimumClassConfusionLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Entry point: build data/model, then train, test, or analyse
    (t-SNE + A-distance) an MCC-adapted image classifier per ``args.phase``."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batch sizes equal for the concatenated forward pass
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # endless iterators so each epoch runs a fixed number of iterations
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)

    # define optimizer and lr scheduler (inverse-decay schedule stepped per iteration)
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    mcc_loss = MinimumClassConfusionLoss(temperature=args.temperature)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr()[0])
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mcc_loss, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set, using the checkpoint that scored best on validation
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          mcc: MinimumClassConfusionLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Train `model` for one epoch with the MCC objective.

    Each iteration minimises source cross-entropy plus the Minimum Class
    Confusion loss on the *target* logits, scaled by ``args.trade_off``.
    The LR scheduler is stepped once per iteration, not per epoch.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # iterators yield (input, label, ...) tuples; keep only what we need
        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: one forward pass over the concatenated batch,
        # then split the logits back into source/target halves
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = mcc(y_t)
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='MCC for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.08 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4.
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true',
                        help='whether train from scratch.')
    parser.add_argument('--temperature', default=2.5, type=float,
                        help='parameter temperature scaling')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.005, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # dest='weight_decay' so the value is read back as args.weight_decay in main()
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs',
default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='mcc',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/image_classification/mcc.sh
================================================
#!/usr/bin/env bash
# ResNet50, Office31, Single Source
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 -i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python mcc.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 
-i 500 --seed 2 --bottleneck-dim 1024 --log logs/mcc/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d 
OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --bottleneck-dim 2048 --log logs/mcc/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=5 python mcc.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \ --epochs 30 --seed 0 --lr 0.002 --per-class-eval --temperature 3.0 --train-resizing cen.crop --log logs/mcc/VisDA2017 # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2c 
CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t c -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2p CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 30 -b 32 -i 2500 -p 500 --temperature 2.0 --lr 0.005 --bottleneck-dim 2048 --trade-off 10.0 --seed 0 --log logs/mcc/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python mcc.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 30 --seed 0 --temperature 2.5 --bottleneck-dim 2048 --log logs/mcc/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python mcc.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d --epochs 30 -i 2500 -p 500 --log logs/mcc_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 
--seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 -b 24 --log logs/mcc_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --bottleneck-dim 2048 --epochs 30 --seed 0 
-b 24 --log logs/mcc_vit/OfficeHome_Rw2Pr # ResNet50, Office-Home, Multi Source CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Ar CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Cl CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Pr CUDA_VISIBLE_DEVICES=0 python mcc.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --bottleneck-dim 2048 --epochs 30 --seed 0 --log logs/mcc/OfficeHome_:2Rw # ResNet101, DomainNet, Multi Source CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2c CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2i CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2p CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2q CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2r CUDA_VISIBLE_DEVICES=0 python mcc.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --bottleneck-dim 2048 --epochs 40 -i 5000 -p 500 --seed 0 --log logs/mcc/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s MNIST -t USPS 
--train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python mcc.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mcc/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/mcd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp from typing import Tuple import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD import torch.utils.data from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.alignment.mcd import ImageClassifierHead, entropy, classifier_discrepancy from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy, ConfusionMatrix from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. 
def main(args: argparse.Namespace):
    """Entry point for MCD domain adaptation: builds data pipelines, the shared
    feature extractor G and the two classifier heads F1/F2, then dispatches on
    ``args.phase`` (train / test / analysis)."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seed python and torch RNGs; cudnn.deterministic trades speed for
        # reproducibility, as the warning below explains.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: separate augmentation pipelines for train and eval.
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last=True keeps source/target batch sizes equal, which the
    # chunk(2, dim=0) splits in train() rely on.
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # ForeverDataIterator re-cycles its loader so training can run a fixed
    # number of iterations per epoch regardless of dataset length.
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model: one backbone G shared by two classifier heads (MCD).
    print("=> using model '{}'".format(args.arch))
    G = utils.get_model(args.arch, pretrain=not args.scratch).to(device)  # feature extractor
    # two image classifier heads
    pool_layer = nn.Identity() if args.no_pool else None
    F1 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)
    F2 = ImageClassifierHead(G.out_features, num_classes, args.bottleneck_dim, pool_layer).to(device)

    # define optimizer
    # the learning rate is fixed according to the original paper; note the
    # generator optimizer has no momentum while the classifier optimizer does.
    optimizer_g = SGD(G.parameters(), lr=args.lr, weight_decay=0.0005)
    optimizer_f = SGD([
        {"params": F1.parameters()},
        {"params": F2.parameters()},
    ], momentum=0.9, lr=args.lr, weight_decay=0.0005)

    # resume from the best checkpoint (test/analysis phases only)
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        G.load_state_dict(checkpoint['G'])
        F1.load_state_dict(checkpoint['F1'])
        F2.load_state_dict(checkpoint['F2'])

    # analyze the model: t-SNE visualization plus A-distance
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(G, F1.pool_layer).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, G, F1, F2, args)
        print(acc1)
        return

    # start training; track the better of the two heads' accuracies
    best_acc1 = 0.
    best_results = None
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, G, F1, F2, optimizer_g,
              optimizer_f, epoch, args)

        # evaluate on validation set
        results = validate(val_loader, G, F1, F2, args)

        # remember best acc@1 and save checkpoint
        torch.save({
            'G': G.state_dict(),
            'F1': F1.state_dict(),
            'F2': F2.state_dict()
        }, logger.get_checkpoint_path('latest'))
        if max(results) > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(results)
        best_results = results

    print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

    # evaluate the best checkpoint on the test set
    checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
    G.load_state_dict(checkpoint['G'])
    F1.load_state_dict(checkpoint['F1'])
    F2.load_state_dict(checkpoint['F2'])
    results = validate(test_loader, G, F1, F2, args)
    print("test_acc1 = {:3.1f}".format(max(results)))

    logger.close()
on source domain optimizer_g.zero_grad() optimizer_f.zero_grad() g = G(x) y_1 = F1(g) y_2 = F2(g) y1_s, y1_t = y_1.chunk(2, dim=0) y2_s, y2_t = y_2.chunk(2, dim=0) y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1) loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \ (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy loss.backward() optimizer_g.step() optimizer_f.step() # Step B train classifier to maximize discrepancy optimizer_g.zero_grad() optimizer_f.zero_grad() g = G(x) y_1 = F1(g) y_2 = F2(g) y1_s, y1_t = y_1.chunk(2, dim=0) y2_s, y2_t = y_2.chunk(2, dim=0) y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1) loss = F.cross_entropy(y1_s, labels_s) + F.cross_entropy(y2_s, labels_s) + \ (entropy(y1_t) + entropy(y2_t)) * args.trade_off_entropy - \ classifier_discrepancy(y1_t, y2_t) * args.trade_off loss.backward() optimizer_f.step() # Step C train genrator to minimize discrepancy for k in range(args.num_k): optimizer_g.zero_grad() g = G(x) y_1 = F1(g) y_2 = F2(g) y1_s, y1_t = y_1.chunk(2, dim=0) y2_s, y2_t = y_2.chunk(2, dim=0) y1_t, y2_t = F.softmax(y1_t, dim=1), F.softmax(y2_t, dim=1) mcd_loss = classifier_discrepancy(y1_t, y2_t) * args.trade_off mcd_loss.backward() optimizer_g.step() cls_acc = accuracy(y1_s, labels_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) trans_losses.update(mcd_loss.item(), x_s.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) def validate(val_loader: DataLoader, G: nn.Module, F1: ImageClassifierHead, F2: ImageClassifierHead, args: argparse.Namespace) -> Tuple[float, float]: batch_time = AverageMeter('Time', ':6.3f') top1_1 = AverageMeter('Acc_1', ':6.2f') top1_2 = AverageMeter('Acc_2', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, top1_1, top1_2], prefix='Test: ') # switch to evaluate mode G.eval() F1.eval() F2.eval() if args.per_class_eval: 
confmat = ConfusionMatrix(len(args.class_names)) else: confmat = None with torch.no_grad(): end = time.time() for i, data in enumerate(val_loader): images, target = data[:2] images = images.to(device) target = target.to(device) # compute output g = G(images) y1, y2 = F1(g), F2(g) # measure accuracy and record loss acc1, = accuracy(y1, target) acc2, = accuracy(y2, target) if confmat: confmat.update(target, y1.argmax(1)) top1_1.update(acc1.item(), images.size(0)) top1_2.update(acc2.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc1 {top1_1.avg:.3f} Acc2 {top1_2.avg:.3f}' .format(top1_1=top1_1, top1_2=top1_2)) if confmat: print(confmat.format(args.class_names)) return top1_1.avg, top1_2.avg if __name__ == '__main__': parser = argparse.ArgumentParser(description='MCD for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') parser.add_argument('--resize-size', type=int, default=224, help='the image size after resizing') parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. 
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=1024, type=int) parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') parser.add_argument('--trade-off-entropy', default=0.01, type=float, help='the trade-off hyper-parameter for entropy loss') parser.add_argument('--num-k', type=int, default=4, metavar='K', help='how many steps to repeat the generator update') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') 
parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='mcd', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_classification/mcd.sh ================================================ #!/usr/bin/env bash # ResNet50, Office31, Single Source # We found MCD loss is sensitive to class number, # thus, when the class number increase, please increase trade-off correspondingly. CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2A CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2A CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2W CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_A2D CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_D2W CUDA_VISIBLE_DEVICES=0 python mcd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 -i 500 --trade-off 10.0 --log logs/mcd/Office31_W2D # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 
--epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_Ar2Cl
# ResNet50, Office-Home, Single Source — remaining transfer tasks.
# All pairs share identical hyper-parameters, so iterate instead of
# repeating the command eleven times.
for pair in Ar:Pr Ar:Rw Cl:Ar Cl:Pr Cl:Rw Pr:Ar Pr:Cl Pr:Rw Rw:Ar Rw:Cl Rw:Pr; do
  src=${pair%:*}
  tgt=${pair#*:}
  CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s ${src} -t ${tgt} -a resnet50 \
    --epochs 20 -i 500 --seed 0 --trade-off 30.0 --log logs/mcd/OfficeHome_${src}2${tgt}
done

# ResNet101, VisDA-2017, Single Source
# FIX: removed the stale '--center-crop' flag — mcd.py's argument parser
# defines no such option (center cropping is already selected via
# '--train-resizing cen.crop'), so the original command aborted with
# "unrecognized arguments".
CUDA_VISIBLE_DEVICES=0 python mcd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \
  --epochs 20 --seed 0 -i 500 --per-class-eval --train-resizing cen.crop --log logs/mcd/VisDA2017

# ResNet101, DomainNet, Single Source (same 13 source:target pairs as before,
# in the same order)
for pair in c:p c:r c:s p:c p:i p:r p:s r:c r:p r:s s:c s:p s:r; do
  src=${pair%:*}
  tgt=${pair#*:}
  CUDA_VISIBLE_DEVICES=0 python mcd.py data/domainnet -d DomainNet -s ${src} -t ${tgt} -a resnet101 \
    --bottleneck-dim 1024 --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 120.0 \
    --log logs/mcd/DomainNet_${src}2${tgt}
done

# ResNet50, ImageNet200 -> ImageNetR
CUDA_VISIBLE_DEVICES=0 python mcd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 \
  --epochs 30 -i 2500 -p 500 --seed 0 --trade-off 100.0 --log logs/mcd/ImageNet_IN2INR

# ig_resnext101_32x8d, ImageNet -> ImageNetSketch
CUDA_VISIBLE_DEVICES=0 python mcd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d \
  --epochs 30 -i 2500 -p 500 --trade-off 500.0 --log logs/mcd_ig_resnext101_32x8d/ImageNet_IN2sketch

# Vision Transformer, Office-Home, Single Source
for pair in Ar:Cl Ar:Pr Ar:Rw Cl:Ar Cl:Pr Cl:Rw Pr:Ar Pr:Cl Pr:Rw Rw:Ar Rw:Cl Rw:Pr; do
  src=${pair%:*}
  tgt=${pair#*:}
  CUDA_VISIBLE_DEVICES=0 python mcd.py data/office-home -d OfficeHome -s ${src} -t ${tgt} -a vit_base_patch16_224 \
    --epochs 30 --seed 0 -b 24 --no-pool --trade-off 30.0 --log logs/mcd_vit/OfficeHome_${src}2${tgt}
done

# Digits
CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.'
\ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.1 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python mcd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.03 --trade-off 0.3 --trade-off-entropy 0.03 -b 128 -i 2500 --scratch --seed 0 --log logs/mcd/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/mdd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import os.path as osp import shutil import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \ as MarginDisparityDiscrepancy, ImageClassifier from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) 
def main(args: argparse.Namespace):
    """Entry point for MDD domain adaptation: builds data pipelines, the
    classifier with its adversarial head, and the margin disparity
    discrepancy loss, then dispatches on ``args.phase``."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Seed python and torch RNGs; deterministic cuDNN for reproducibility.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                                random_horizontal_flip=not args.no_hflip,
                                                random_color_jitter=False, resize_size=args.resize_size,
                                                norm_mean=args.norm_mean, norm_std=args.norm_std)
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    # drop_last keeps source/target batches the same size for the chunk()
    # splits in train().
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model: MDD's ImageClassifier bundles the backbone, a bottleneck,
    # the main head and an adversarial head.
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 width=args.bottleneck_dim, pool_layer=pool_layer).to(device)
    mdd = MarginDisparityDiscrepancy(args.margin).to(device)

    # define optimizer and lr_scheduler
    # The learning rate of the classifiers are set 10 times to that of the
    # feature extractor by default (handled inside get_parameters()).
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    # polynomial decay schedule over optimizer steps
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint (test/analysis phases only)
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model: t-SNE visualization plus A-distance
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, mdd, optimizer,
              lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate the best checkpoint on the test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          classifier: ImageClassifier, mdd: MarginDisparityDiscrepancy, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """One epoch of MDD training: source cross-entropy plus the (negated)
    margin disparity discrepancy between the main and adversarial heads."""
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    classifier.train()
    mdd.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, labels_s = next(train_source_iter)[:2]
        x_t, = next(train_target_iter)[:1]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: one pass over the concatenated batch, first half
        # source, second half target
        x = torch.cat((x_s, x_t), dim=0)
        outputs, outputs_adv = classifier(x)
        y_s, y_t = outputs.chunk(2, dim=0)
        y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0)

        # compute cross entropy loss on source domain
        cls_loss = F.cross_entropy(y_s, labels_s)
        # compute margin disparity discrepancy between domains
        # for adversarial classifier, minimize negative mdd is equal to maximize mdd
        transfer_loss = -mdd(y_s, y_s_adv, y_t, y_t_adv)
        loss = cls_loss + transfer_loss * args.trade_off
        # NOTE(review): classifier.step() is called before backward();
        # presumably it advances an internal iteration counter (e.g. the
        # gradient-reverse-layer coefficient schedule) rather than updating
        # weights — confirm against tllib's MDD ImageClassifier.
        classifier.step()

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        trans_losses.update(transfer_loss.item(), x_s.size(0))

        # compute gradient and do SGD step
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
/ 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--norm-mean', type=float, nargs='+', default=(0.485, 0.456, 0.406), help='normalization mean') parser.add_argument('--norm-std', type=float, nargs='+', default=(0.229, 0.224, 0.225), help='normalization std') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=1024, type=int) parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--margin', type=float, default=4., help="margin gamma") parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.004, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0002, type=float) parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', 
'--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='mdd', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_classification/mdd.sh ================================================ #!/usr/bin/env bash # ResNet50, Office31, Single Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2W CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2W CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2D CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_A2D CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_D2A CUDA_VISIBLE_DEVICES=0 python mdd.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --bottleneck-dim 1024 --seed 1 --log logs/mdd/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python mdd.py 
data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single 
Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 --epochs 30 \ --bottleneck-dim 1024 --seed 0 --train-resizing cen.crop --per-class-eval -b 36 --log logs/mdd/VisDA2017 # ResNet101, DomainNet, Single Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2p CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2r CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_c2s CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2c CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2r CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s p -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_p2s CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2c CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2p CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_r2s CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d 
DomainNet -s s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2c CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2p CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_s2r # ResNet50, ImageNet200 -> ImageNetR CUDA_VISIBLE_DEVICES=0 python mdd.py data/ImageNetR -d ImageNetR -s IN -t INR -a resnet50 --epochs 40 -i 2500 -p 500 \ --bottleneck-dim 2048 --seed 0 --lr 0.004 --train-resizing cen.crop --log logs/mdd/ImageNet_IN2INR # ig_resnext101_32x8d, ImageNet -> ImageNetSketch CUDA_VISIBLE_DEVICES=0 python mdd.py data/imagenet-sketch -d ImageNetSketch -s IN -t sketch -a ig_resnext101_32x8d \ --epochs 40 -i 2500 -p 500 --bottleneck-dim 2048 --margin 2. --seed 0 --lr 0.004 --train-resizing cen.crop \ --log logs/mdd_ig_resnext101_32x8d/ImageNet_IN2sketch # Vision Transformer, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d 
OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --epochs 40 --bottleneck-dim 2048 --seed 0 -b 24 --no-pool --log logs/mdd_vit/OfficeHome_Rw2Pr # ResNet50, Office-Home, Multi Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Cl Pr Rw -t Ar -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Ar CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Pr Rw -t Cl -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Cl CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Rw -t 
Pr -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Pr CUDA_VISIBLE_DEVICES=0 python mdd.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --epochs 30 --bottleneck-dim 2048 --seed 0 --log logs/mdd/OfficeHome_:2Rw # ResNet101, DomainNet, Multi Source CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2c CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2i CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2p CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2q CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2r CUDA_VISIBLE_DEVICES=0 python mdd.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet101 --epochs 40 -i 5000 -p 500 --bottleneck-dim 2048 --seed 0 --lr 0.004 --log logs/mdd/DomainNet_:2s # Digits CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s MNIST -t USPS --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/MNIST2USPS CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s USPS -t MNIST --train-resizing 'res.' --val-resizing 'res.' 
\ --resize-size 28 --no-hflip --norm-mean 0.5 --norm-std 0.5 -a lenet --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/USPS2MNIST CUDA_VISIBLE_DEVICES=0 python mdd.py data/digits -d Digits -s SVHNRGB -t MNISTRGB --train-resizing 'res.' --val-resizing 'res.' \ --resize-size 32 --no-hflip --norm-mean 0.5 0.5 0.5 --norm-std 0.5 0.5 0.5 -a dtn --no-pool --lr 0.01 -b 128 -i 2500 --scratch --seed 0 --log logs/mdd/SVHN2MNIST ================================================ FILE: examples/domain_adaptation/image_classification/requirements.txt ================================================ timm ================================================ FILE: examples/domain_adaptation/image_classification/self_ensemble.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F import utils from tllib.self_training.pi_model import L2ConsistencyLoss from tllib.self_training.mean_teacher import EMATeacher from tllib.self_training.self_ensemble import ClassBalanceLoss, ImageClassifier from tllib.vision.transforms import ResizeImage, MultipleApply from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True 
warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code # we find self ensemble is sensitive to data augmentation. The following # data augmentation performs well for evaluated datasets normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = T.Compose([ ResizeImage(256), T.RandomCrop(224), T.RandomHorizontalFlip(), T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5), T.RandomGrayscale(), T.ToTensor(), normalize ]) val_transform = T.Compose([ ResizeImage(256), T.CenterCrop(224), T.ToTensor(), normalize ]) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform, MultipleApply([train_transform, val_transform])) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) # define optimizer 
and lr scheduler optimizer = Adam(classifier.get_parameters(), args.lr) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return if args.pretrain is None: # first pretrain the classifier wish source data print("Pretraining the model on source domain.") args.pretrain = logger.get_checkpoint_path('pretrain') pretrain_model = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch).to(device) pretrain_optimizer = Adam(pretrain_model.get_parameters(), args.pretrain_lr) pretrain_lr_scheduler = LambdaLR(pretrain_optimizer, lambda x: args.pretrain_lr * (1. 
+ args.lr_gamma * float(x)) ** ( -args.lr_decay)) # start pretraining for epoch in range(args.pretrain_epochs): # pretrain for one epoch utils.empirical_risk_minimization(train_source_iter, pretrain_model, pretrain_optimizer, pretrain_lr_scheduler, epoch, args, device) # validate to show pretrain process utils.validate(val_loader, pretrain_model, args, device) torch.save(pretrain_model.state_dict(), args.pretrain) print("Pretraining process is done.") checkpoint = torch.load(args.pretrain, map_location='cpu') classifier.load_state_dict(checkpoint) teacher = EMATeacher(classifier, alpha=args.alpha) consistency_loss = L2ConsistencyLoss().to(device) class_balance_loss = ClassBalanceLoss(num_classes).to(device) # start training best_acc1 = 0. for epoch in range(args.epochs): print(lr_scheduler.get_lr()) # train for one epoch train(train_source_iter, train_target_iter, classifier, teacher, consistency_loss, class_balance_loss, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, teacher: EMATeacher, consistency_loss, class_balance_loss, optimizer: Adam, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':3.1f') data_time = AverageMeter('Data', ':3.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') cons_losses = AverageMeter('Cons 
Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, cls_losses, cons_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() teacher.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter)[:2] (x_t1, x_t2), = next(train_target_iter)[:1] x_s = x_s.to(device) x_t1 = x_t1.to(device) x_t2 = x_t2.to(device) labels_s = labels_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s, _ = model(x_s) y_t, _ = model(x_t1) y_t_teacher, _ = teacher(x_t2) # classification loss cls_loss = F.cross_entropy(y_s, labels_s) # compute output and mask p_t = F.softmax(y_t, dim=1) p_t_teacher = F.softmax(y_t_teacher, dim=1) confidence, _ = p_t_teacher.max(dim=1) mask = (confidence > args.threshold).float() # consistency loss cons_loss = consistency_loss(p_t, p_t_teacher, mask) # balance loss balance_loss = class_balance_loss(p_t) * mask.mean() loss = cls_loss + args.trade_off_cons * cons_loss + args.trade_off_balance * balance_loss # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # update teacher teacher.update() # update statistics cls_acc = accuracy(y_s, labels_s)[0] cls_losses.update(cls_loss.item(), x_s.size(0)) cons_losses.update(cons_loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Self Ensemble for Unsupervised Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: 
Office31)') parser.add_argument('-s', '--source', help='source domain(s)', nargs='+') parser.add_argument('-t', '--target', help='target domain(s)', nargs='+') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--pretrain', type=str, default=None, help='pretrain checkpoint for classification model') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=1e-4, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--pretrain-lr', '--pretrain-learning-rate', default=3e-5, type=float, help='initial pretrain learning rate', dest='pretrain_lr') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--alpha', default=0.99, type=float, help='ema decay rate (default: 0.99)') parser.add_argument('--threshold', default=0.8, type=float, help='confidence threshold') parser.add_argument('--trade-off-cons', default=3, type=float, help='trade off parameter for consistency loss') parser.add_argument('--trade-off-balance', default=0.01, type=float, help='trade off parameter for class balance loss') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--pretrain-epochs', default=0, type=int, 
metavar='N', help='number of total epochs(pretrain) to run') parser.add_argument('--epochs', default=10, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='self_ensemble', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_classification/self_ensemble.sh ================================================ #!/usr/bin/env bash # ResNet50, Office31, Single Source CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2W CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t W -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2W CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_W2D CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s A -t D -a resnet50 --seed 1 --log logs/self_ensemble/Office31_A2D CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s D -t A -a resnet50 --seed 1 --log logs/self_ensemble/Office31_D2A CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office31 -d Office31 -s W -t A -a 
resnet50 --seed 1 --log logs/self_ensemble/Office31_W2A # ResNet50, Office-Home, Single Source CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --seed 0 --log logs/self_ensemble/OfficeHome_Rw2Pr # ResNet101, VisDA-2017, Single Source CUDA_VISIBLE_DEVICES=0 python 
self_ensemble.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet101 \ --epochs 20 --seed 0 --per-class-eval --log logs/self_ensemble/VisDA2017 --lr-gamma 0.0002 -b 32 # Office-Home on Vision Transformer CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Ar -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Cl -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Pr -t Rw -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Pr2Rw 
CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Ar -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Cl -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python self_ensemble.py data/office-home -d OfficeHome -s Rw -t Pr -a vit_base_patch16_224 --no-pool --epochs 30 --seed 0 -b 24 --log logs/self_ensemble_vit/OfficeHome_Rw2Pr ================================================ FILE: examples/domain_adaptation/image_classification/utils.py ================================================ """ @author: Junguang Jiang, Baixu Chen @contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com """ import sys import os.path as osp import time from PIL import Image import timm import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T from timm.data.auto_augment import auto_augment_transform, rand_augment_transform sys.path.append('../../..') import tllib.vision.datasets as datasets import tllib.vision.models as models from tllib.vision.transforms import ResizeImage from tllib.utils.metric import accuracy, ConfusionMatrix from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.vision.datasets.imagelist import MultipleDomainsDataset def get_model_names(): return sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) + timm.list_models() def get_model(model_name, pretrain=True): if model_name in models.__dict__: # load models from tllib.vision.models backbone = models.__dict__[model_name](pretrained=pretrain) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=pretrain) try: backbone.out_features = backbone.get_classifier().in_features 
backbone.reset_classifier(0, '') except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() return backbone def get_dataset_names(): return sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) + ['Digits'] def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None): if train_target_transform is None: train_target_transform = train_source_transform if dataset_name == "Digits": train_source_dataset = datasets.__dict__[source[0]](osp.join(root, source[0]), download=True, transform=train_source_transform) train_target_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), download=True, transform=train_target_transform) val_dataset = test_dataset = datasets.__dict__[target[0]](osp.join(root, target[0]), split='test', download=True, transform=val_transform) class_names = datasets.MNIST.get_classes() num_classes = len(class_names) elif dataset_name in datasets.__dict__: # load datasets from tllib.vision.datasets dataset = datasets.__dict__[dataset_name] def concat_dataset(tasks, start_idx, **kwargs): # return ConcatDataset([dataset(task=task, **kwargs) for task in tasks]) return MultipleDomainsDataset([dataset(task=task, **kwargs) for task in tasks], tasks, domain_ids=list(range(start_idx, start_idx + len(tasks)))) train_source_dataset = concat_dataset(root=root, tasks=source, download=True, transform=train_source_transform, start_idx=0) train_target_dataset = concat_dataset(root=root, tasks=target, download=True, transform=train_target_transform, start_idx=len(source)) val_dataset = concat_dataset(root=root, tasks=target, download=True, transform=val_transform, start_idx=len(source)) if dataset_name == 'DomainNet': test_dataset = concat_dataset(root=root, tasks=target, split='test', download=True, transform=val_transform, start_idx=len(source)) else: test_dataset = val_dataset class_names = 
train_source_dataset.datasets[0].classes
        num_classes = len(class_names)
    else:
        raise NotImplementedError(dataset_name)
    return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names


def validate(val_loader, model, args, device) -> float:
    """Evaluate ``model`` on ``val_loader`` and return the top-1 accuracy (percent).

    When ``args.per_class_eval`` is set, a confusion matrix is accumulated and
    printed so per-class accuracy can be inspected.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    if args.per_class_eval:
        confmat = ConfusionMatrix(len(args.class_names))
    else:
        confmat = None

    with torch.no_grad():
        end = time.time()
        for i, data in enumerate(val_loader):
            # only the first two fields are used; loaders may yield extra
            # fields (e.g. domain ids) after (image, label)
            images, target = data[:2]
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = F.cross_entropy(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1,))
            if confmat:
                confmat.update(target, output.argmax(1))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
        if confmat:
            print(confmat.format(args.class_names))

    return top1.avg


def get_train_transform(resizing='default', scale=(0.08, 1.0), ratio=(3. / 4., 4.
/ 3.), random_horizontal_flip=True, random_color_jitter=False, resize_size=224,
                        norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225),
                        auto_augment=None):
    """
    resizing mode:
        - default: resize the image to 256 and take a random resized crop of size 224;
        - cen.crop: resize the image to 256 and take the center crop of size 224;
        - ran.crop: resize the image to 256 and take a random crop of size 224;
        - res.: resize the image to ``resize_size``;
    """
    transformed_img_size = 224
    if resizing == 'default':
        transform = T.Compose([
            ResizeImage(256),
            T.RandomResizedCrop(224, scale=scale, ratio=ratio)
        ])
    elif resizing == 'cen.crop':
        transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224)
        ])
    elif resizing == 'ran.crop':
        transform = T.Compose([
            ResizeImage(256),
            T.RandomCrop(224)
        ])
    elif resizing == 'res.':
        transform = ResizeImage(resize_size)
        transformed_img_size = resize_size
    else:
        raise NotImplementedError(resizing)
    transforms = [transform]
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if auto_augment:
        # AutoAugment/RandAugment parameters follow the timm convention
        aa_params = dict(
            translate_const=int(transformed_img_size * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]),
            interpolation=Image.BILINEAR
        )
        if auto_augment.startswith('rand'):
            transforms.append(rand_augment_transform(auto_augment, aa_params))
        else:
            transforms.append(auto_augment_transform(auto_augment, aa_params))
    elif random_color_jitter:
        # color jitter is only applied when auto augment is disabled
        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    transforms.extend([
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ])
    return T.Compose(transforms)


def get_val_transform(resizing='default', resize_size=224,
                      norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """
    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to ``resize_size``
    """
    if resizing == 'default':
        transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
        ])
    elif resizing == 'res.':
        transform = ResizeImage(resize_size)
    else:
        raise NotImplementedError(resizing)
    return T.Compose([
        transform,
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ])


def empirical_risk_minimization(train_source_iter, model, optimizer, lr_scheduler, epoch, args, device):
    """Train ``model`` for one epoch with plain cross-entropy on source-domain data (ERM baseline).

    Runs ``args.iters_per_epoch`` steps, drawing batches from
    ``train_source_iter`` and stepping both the optimizer and the lr scheduler
    after every batch. Progress is printed every ``args.print_freq`` steps.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # only the first two fields are used; the iterator may yield extra fields
        x_s, labels_s = next(train_source_iter)[:2]
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output (model returns logits and features; features unused here)
        y_s, f_s = model(x_s)

        cls_loss = F.cross_entropy(y_s, labels_s)
        loss = cls_loss

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


================================================
FILE: examples/domain_adaptation/image_regression/README.md
================================================
# Unsupervised Domain Adaptation for Image Regression Tasks

It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to better reproduce the benchmark results.
## Dataset

The following datasets can be downloaded automatically:

- [DSprites](https://github.com/deepmind/dsprites-dataset)
- [MPI3D](https://github.com/rr-learning/disentanglement_dataset)

## Supported Methods

Supported methods include:

- [Disparity Discrepancy (DD)](https://arxiv.org/abs/1904.05801)
- [Representation Subspace Distance (RSD)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Representation-Subspace-Distance-for-Domain-Adaptation-Regression-icml21.pdf)

## Experiment and Results

The shell files give the scripts to reproduce the benchmark results with specified hyper-parameters.
For example, if you want to train DD on DSprites, use the following script

```shell script
# Train DD on the DSprites C->N task using ResNet 18.
# Assume you have put the datasets under the path `data/dSprites`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2N --wd 0.0005
```

**Notations**

- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by Transfer-Learn.
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.

Labels are all normalized to [0, 1] to eliminate the effects of diverse scale in regression values.
We repeat experiments on DD three times and report the average error of the ``final`` epoch.
### dSprites error on ResNet-18 | Methods | Avg | C → N | C → S | N → C | N → S | S → C | S → N | |-------------|-------|-------|-------|-------|-------|-------|-------| | ERM | 0.157 | 0.232 | 0.271 | 0.081 | 0.22 | 0.038 | 0.092 | | DD | 0.057 | 0.047 | 0.08 | 0.03 | 0.095 | 0.053 | 0.037 | ### MPI3D error on ResNet-18 | Methods | Avg | RL → RC | RL → T | RC → RL | RC → T | T → RL | T → RC | |-------------|-------|---------|--------|---------|--------|--------|--------| | ERM | 0.176 | 0.232 | 0.271 | 0.081 | 0.22 | 0.038 | 0.092 | | DD | 0.03 | 0.086 | 0.029 | 0.057 | 0.189 | 0.131 | 0.087 | ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{MDD, title={Bridging theory and algorithm for domain adaptation}, author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael}, booktitle={ICML}, year={2019}, } @inproceedings{RSD, title={Representation Subspace Distance for Domain Adaptation Regression}, author={Chen, Xinyang and Wang, Sinan and Wang, Jianmin and Long, Mingsheng}, booktitle={ICML}, year={2021} } ``` ================================================ FILE: examples/domain_adaptation/image_regression/dann.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F import utils from tllib.modules.regressor import Regressor from tllib.alignment.dann import DomainAdversarialLoss from tllib.modules.domain_discriminator import DomainDiscriminator import tllib.vision.datasets.regression as datasets import tllib.vision.models as models from tllib.utils.data import ForeverDataIterator from 
tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) dataset = datasets.__dict__[args.data] train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True, transform=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True, transform=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = models.__dict__[args.arch](pretrained=True) if 
args.normalization == 'IN': backbone = utils.convert_model(backbone) num_factors = train_source_dataset.num_factors bottleneck = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(backbone.out_features, 256), nn.ReLU() ) regressor = Regressor(backbone=backbone, num_factors=num_factors, bottleneck=bottleneck, bottleneck_dim=256).to( device) print(regressor) domain_discri = DomainDiscriminator(in_feature=regressor.features_dim, hidden_size=1024).to(device) # define optimizer and lr scheduler optimizer = SGD(regressor.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function dann = DomainAdversarialLoss(domain_discri).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') regressor.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) # extract features from both domains feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", 
A_distance) return if args.phase == 'test': mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) print(mae) return # start training best_mae = 100000. for epoch in range(args.epochs): # train for one epoch print("lr", lr_scheduler.get_lr()) train(train_source_iter, train_target_iter, regressor, dann, optimizer, lr_scheduler, epoch, args) # evaluate on validation set mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) # remember best mae and save checkpoint torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest')) if mae < best_mae: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_mae = min(mae, best_mae) print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae)) print("best_mae = {:6.3f}".format(best_mae)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: Regressor, domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') mse_losses = AverageMeter('MSE Loss', ':6.3f') dann_losses = AverageMeter('DANN Loss', ':6.3f') domain_accs = AverageMeter('Domain Acc', ':3.1f') mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f') mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, mse_losses, dann_losses, mae_losses_s, mae_losses_t, domain_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() domain_adv.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, labels_s = next(train_source_iter) x_s = x_s.to(device) labels_s = labels_s.to(device).float() x_t, labels_t = next(train_target_iter) x_t = x_t.to(device) labels_t = labels_t.to(device).float() # measure data loading time data_time.update(time.time() - end) # 
compute output y_s, f_s = model(x_s) y_t, f_t = model(x_t) mse_loss = F.mse_loss(y_s, labels_s) mae_loss_s = F.l1_loss(y_s, labels_s) mae_loss_t = F.l1_loss(y_t, labels_t) transfer_loss = domain_adv(f_s, f_t) loss = mse_loss + transfer_loss * args.trade_off domain_acc = domain_adv.domain_discriminator_accuracy mse_losses.update(mse_loss.item(), x_s.size(0)) dann_losses.update(transfer_loss.item(), x_s.size(0)) mae_losses_s.update(mae_loss_s.item(), x_s.size(0)) mae_losses_t.update(mae_loss_t.item(), x_s.size(0)) domain_accs.update(domain_acc.item(), x_s.size(0)) # compute gradient and do SGD step loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='DANN for Regression Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='DSprites', help='dataset: ' + ' | '.join(dataset_names) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-size', type=int, default=128) # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet18)') parser.add_argument('--normalization', default='BN', type=str, choices=["BN", "IN"]) parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer 
loss') # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.001, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='dann', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_regression/dann.sh ================================================ # DSprites CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2N CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_C2S CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2C CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_N2S CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2C CUDA_VISIBLE_DEVICES=0 python dann.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/dann/DSprites_S2N # MPI3D CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2RC --resize-size 224 CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RL2T --resize-size 224 CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2RL --resize-size 224 CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_RC2T --resize-size 224 CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RL --resize-size 224 CUDA_VISIBLE_DEVICES=0 python dann.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/dann/MPI3D_T2RC --resize-size 224 
================================================ FILE: examples/domain_adaptation/image_regression/dd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F import torch.nn as nn import utils from tllib.alignment.mdd import RegressionMarginDisparityDiscrepancy as MarginDisparityDiscrepancy, ImageRegressor import tllib.vision.datasets.regression as datasets import tllib.vision.models as models from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) dataset = datasets.__dict__[args.data] train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True, transform=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True, transform=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) num_factors = train_source_dataset.num_factors backbone = models.__dict__[args.arch](pretrained=True) bottleneck_dim = args.bottleneck_dim if args.normalization == 'IN': backbone = utils.convert_model(backbone) bottleneck = nn.Sequential( nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(), ) head = nn.Sequential( nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(), nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(), 
nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(bottleneck_dim, num_factors), nn.Sigmoid() ) for layer in head: if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): nn.init.normal_(layer.weight, 0, 0.01) nn.init.constant_(layer.bias, 0) adv_head = nn.Sequential( nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(), nn.Conv2d(bottleneck_dim, bottleneck_dim, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(bottleneck_dim), nn.ReLU(), nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(bottleneck_dim, num_factors), nn.Sigmoid() ) for layer in adv_head: if isinstance(layer, nn.Conv2d) or isinstance(layer, nn.Linear): nn.init.normal_(layer.weight, 0, 0.01) nn.init.constant_(layer.bias, 0) regressor = ImageRegressor(backbone, num_factors, bottleneck=bottleneck, head=head, adv_head=adv_head, bottleneck_dim=bottleneck_dim, width=bottleneck_dim) else: regressor = ImageRegressor(backbone, num_factors, bottleneck_dim=bottleneck_dim, width=bottleneck_dim) regressor = regressor.to(device) print(regressor) mdd = MarginDisparityDiscrepancy(args.margin).to(device) # define optimizer and lr scheduler optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') regressor.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) # extract features from both domains feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck, regressor.head[:-2]).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) print(mae) return # start training best_mae = 100000. 
for epoch in range(args.epochs): # train for one epoch print("lr", lr_scheduler.get_lr()) train(train_source_iter, train_target_iter, regressor, mdd, optimizer, lr_scheduler, epoch, args) # evaluate on validation set mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) # remember best mae and save checkpoint torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest')) if mae < best_mae: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_mae = min(mae, best_mae) print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae)) print("best_mae = {:6.3f}".format(best_mae)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model, mdd: MarginDisparityDiscrepancy, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') source_losses = AverageMeter('Source Loss', ':6.3f') trans_losses = AverageMeter('Trans Loss', ':6.3f') mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f') mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, source_losses, trans_losses, mae_losses_s, mae_losses_t], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() mdd.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, labels_s = next(train_source_iter) x_s = x_s.to(device) labels_s = labels_s.to(device).float() x_t, labels_t = next(train_target_iter) x_t = x_t.to(device) labels_t = labels_t.to(device).float() # measure data loading time data_time.update(time.time() - end) # compute output x = torch.cat([x_s, x_t], dim=0) outputs, outputs_adv = model(x) y_s, y_t = outputs.chunk(2, dim=0) y_s_adv, y_t_adv = outputs_adv.chunk(2, dim=0) # compute mean square loss on source domain mse_loss = F.mse_loss(y_s, labels_s) # compute 
margin disparity discrepancy between domains transfer_loss = mdd(y_s, y_s_adv, y_t, y_t_adv) # for adversarial classifier, minimize negative mdd is equal to maximize mdd loss = mse_loss - transfer_loss * args.trade_off model.step() mae_loss_s = F.l1_loss(y_s, labels_s) mae_loss_t = F.l1_loss(y_t, labels_t) source_losses.update(mse_loss.item(), x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) mae_losses_s.update(mae_loss_s.item(), x_s.size(0)) mae_losses_t.update(mae_loss_t.item(), x_s.size(0)) # compute gradient and do SGD step loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='DD for Regression Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='DSprites', help='dataset: ' + ' | '.join(dataset_names) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-size', type=int, default=128) # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet18)') parser.add_argument('--bottleneck-dim', default=512, type=int) parser.add_argument('--normalization', default='BN', type=str, choices=['BN', 'IN']) parser.add_argument('--margin', type=float, default=1., help="margin gamma") 
parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='dd', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_regression/dd.sh ================================================ # DSprites CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2N --wd 0.0005 CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_C2S --wd 0.0005 CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2C --wd 0.0005 CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_N2S --wd 0.0005 CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2C --wd 0.0005 CUDA_VISIBLE_DEVICES=0 python dd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 -b 128 --log logs/dd/dSprites_S2N --wd 0.0005 # MPI3D CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2RC --normalization IN --resize-size 224 --weight-decay 0.001 CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RL2T --normalization IN --resize-size 224 --weight-decay 0.001 CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2RL --normalization IN --resize-size 224 --weight-decay 0.001 CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_RC2T --normalization IN --resize-size 224 --weight-decay 0.001 CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RL 
-a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RL --normalization IN --resize-size 224 --weight-decay 0.001 CUDA_VISIBLE_DEVICES=0 python dd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 60 --seed 0 -b 36 --log logs/dd/MPI3D_T2RC --normalization IN --resize-size 224 --weight-decay 0.001 ================================================ FILE: examples/domain_adaptation/image_regression/erm.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F import utils from tllib.modules.regressor import Regressor import tllib.vision.datasets.regression as datasets import tllib.vision.models as models from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) dataset = datasets.__dict__[args.data] train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True, transform=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True, transform=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = models.__dict__[args.arch](pretrained=True) if args.normalization == 'IN': backbone = utils.convert_model(backbone) num_factors = train_source_dataset.num_factors regressor = Regressor(backbone=backbone, num_factors=num_factors).to(device) print(regressor) # define optimizer and lr scheduler optimizer = SGD(regressor.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') regressor.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) # extract features from both domains feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) print(mae) return # start training best_mae = 100000. 
for epoch in range(args.epochs): # train for one epoch print("lr", lr_scheduler.get_lr()) train(train_source_iter, train_target_iter, regressor, optimizer, lr_scheduler, epoch, args) # evaluate on validation set mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) # remember best mae and save checkpoint torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest')) if mae < best_mae: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_mae = min(mae, best_mae) print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae)) print("best_mae = {:6.3f}".format(best_mae)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: Regressor, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') mse_losses = AverageMeter('MSE Loss', ':6.3f') mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f') mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, mse_losses, mae_losses_s, mae_losses_t], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, labels_s = next(train_source_iter) x_s = x_s.to(device) labels_s = labels_s.to(device).float() x_t, labels_t = next(train_target_iter) x_t = x_t.to(device) labels_t = labels_t.to(device).float() # measure data loading time data_time.update(time.time() - end) # compute output y_s, _ = model(x_s) y_t, _ = model(x_t) mse_loss = F.mse_loss(y_s, labels_s) mae_loss_s = F.l1_loss(y_s, labels_s) mae_loss_t = F.l1_loss(y_t, labels_t) loss = mse_loss mse_losses.update(mse_loss.item(), x_s.size(0)) mae_losses_s.update(mae_loss_s.item(), x_s.size(0)) mae_losses_t.update(mae_loss_t.item(), x_s.size(0)) # compute gradient and do SGD step 
loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='Source Only for Regression Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='DSprites', help='dataset: ' + ' | '.join(dataset_names) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-size', type=int, default=128) # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet18)') parser.add_argument('--normalization', default='BN', type=str, choices=["IN", "BN"]) # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', 
default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='src_only', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_regression/erm.sh ================================================ # DSprites CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2N CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_C2S CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2C CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_N2S CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2C CUDA_VISIBLE_DEVICES=0 python erm.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 20 --seed 0 -b 128 --log logs/erm/DSprites_S2N # MPI3D CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 -b 36 
--log logs/erm/MPI3D_RL2RC CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RL2T CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2RL CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_RC2T CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RL CUDA_VISIBLE_DEVICES=0 python erm.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 -b 36 --log logs/erm/MPI3D_T2RC ================================================ FILE: examples/domain_adaptation/image_regression/rsd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F import utils from tllib.modules.regressor import Regressor from tllib.alignment.rsd import RepresentationSubspaceDistance import tllib.vision.datasets.regression as datasets import tllib.vision.models as models from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen 
to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) train_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.resize_size), T.ToTensor(), normalize ]) dataset = datasets.__dict__[args.data] train_source_dataset = dataset(root=args.root, task=args.source, split='train', download=True, transform=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_dataset = dataset(root=args.root, task=args.target, split='train', download=True, transform=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_dataset = dataset(root=args.root, task=args.target, split='test', download=True, transform=val_transform) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = models.__dict__[args.arch](pretrained=True) if args.normalization == 'IN': backbone = utils.convert_model(backbone) num_factors = train_source_dataset.num_factors bottleneck = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten(), nn.Linear(backbone.out_features, 256), nn.ReLU() ) regressor = Regressor(backbone=backbone, num_factors=num_factors, bottleneck=bottleneck, bottleneck_dim=256).to(device) print(regressor) # define optimizer and lr scheduler optimizer = SGD(regressor.get_parameters(), args.lr, 
momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function rsd = RepresentationSubspaceDistance(args.trade_off_bmp) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') regressor.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) # extract features from both domains feature_extractor = nn.Sequential(regressor.backbone, regressor.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) print(mae) return # start training best_mae = 100000. 
for epoch in range(args.epochs): # train for one epoch print("lr", lr_scheduler.get_lr()) train(train_source_iter, train_target_iter, regressor, rsd, optimizer, lr_scheduler, epoch, args) # evaluate on validation set mae = utils.validate(val_loader, regressor, args, train_source_dataset.factors, device) # remember best mae and save checkpoint torch.save(regressor.state_dict(), logger.get_checkpoint_path('latest')) if mae < best_mae: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_mae = min(mae, best_mae) print("mean MAE {:6.3f} best MAE {:6.3f}".format(mae, best_mae)) print("best_mae = {:6.3f}".format(best_mae)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: Regressor, rsd, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') mse_losses = AverageMeter('MSE Loss', ':6.3f') rsd_losses = AverageMeter('RSD Loss', ':6.3f') mae_losses_s = AverageMeter('MAE Loss (s)', ':6.3f') mae_losses_t = AverageMeter('MAE Loss (t)', ':6.3f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, mse_losses, rsd_losses, mae_losses_s, mae_losses_t], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, labels_s = next(train_source_iter) x_s = x_s.to(device) labels_s = labels_s.to(device).float() x_t, labels_t = next(train_target_iter) x_t = x_t.to(device) labels_t = labels_t.to(device).float() # measure data loading time data_time.update(time.time() - end) # compute output y_s, f_s = model(x_s) y_t, f_t = model(x_t) mse_loss = F.mse_loss(y_s, labels_s) mae_loss_s = F.l1_loss(y_s, labels_s) mae_loss_t = F.l1_loss(y_t, labels_t) rsd_loss = rsd(f_s, f_t) loss = mse_loss + rsd_loss * args.trade_off mse_losses.update(mse_loss.item(), x_s.size(0)) 
rsd_losses.update(rsd_loss.item(), x_s.size(0)) mae_losses_s.update(mae_loss_s.item(), x_s.size(0)) mae_losses_t.update(mae_loss_t.item(), x_s.size(0)) # compute gradient and do SGD step loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='RSD for Regression Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='DSprites', help='dataset: ' + ' | '.join(dataset_names) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-size', type=int, default=128) # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet18)') parser.add_argument('--normalization', default='BN', type=str, choices=["BN", "IN"]) parser.add_argument('--trade-off', default=0.001, type=float) parser.add_argument('--trade-off-bmp', default=0.1, type=float) # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', 
default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.001, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='rsd', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/image_regression/rsd.sh ================================================ # DSprites CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2N CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s C -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_C2S CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2C CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s N -t S -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_N2S CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t C -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2C CUDA_VISIBLE_DEVICES=0 python rsd.py data/dSprites -d DSprites -s S -t N -a resnet18 --epochs 40 --seed 0 --log logs/rsd/DSprites_S2N # MPI3D CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2RC --resize-size 224 CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RL -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RL2T --resize-size 224 CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2RL --resize-size 224 CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s RC -t T -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_RC2T --resize-size 224 CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RL -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RL --resize-size 224 CUDA_VISIBLE_DEVICES=0 python rsd.py data/mpi3d -d MPI3D -s T -t RC -a resnet18 --epochs 40 --seed 0 --log logs/rsd/MPI3D_T2RC --resize-size 224 ================================================ FILE: 
examples/domain_adaptation/image_regression/utils.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import sys import time import torch import torch.nn.functional as F from torch.nn.modules.batchnorm import BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.nn.modules.instancenorm import InstanceNorm1d, InstanceNorm2d, InstanceNorm3d sys.path.append('../../..') from tllib.utils.meter import AverageMeter, ProgressMeter def convert_model(module): """convert BatchNorms in the `module` into InstanceNorms""" source_modules = (BatchNorm1d, BatchNorm2d, BatchNorm3d) target_modules = (InstanceNorm1d, InstanceNorm2d, InstanceNorm3d) for src_module, tgt_module in zip(source_modules, target_modules): if isinstance(module, src_module): mod = tgt_module(module.num_features, module.eps, module.momentum, module.affine) module = mod for name, child in module.named_children(): module.add_module(name, convert_model(child)) return module def validate(val_loader, model, args, factors, device): batch_time = AverageMeter('Time', ':6.3f') mae_losses = [AverageMeter('mae {}'.format(factor), ':6.3f') for factor in factors] progress = ProgressMeter( len(val_loader), [batch_time] + mae_losses, prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): images = images.to(device) target = target.to(device) # compute output output = model(images) for j in range(len(factors)): mae_loss = F.l1_loss(output[:, j], target[:, j]) mae_losses[j].update(mae_loss.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) for i, factor in enumerate(factors): print("{} MAE {mae.avg:6.3f}".format(factor, mae=mae_losses[i])) mean_mae = sum(l.avg for l in mae_losses) / len(factors) return mean_mae ================================================ FILE: 
examples/domain_adaptation/keypoint_detection/README.md ================================================ # Unsupervised Domain Adaptation for Keypoint Detection It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to better reproduce the benchmark results. ## Dataset Following datasets can be downloaded automatically: - [Rendered Handpose Dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html) - [Hand-3d-Studio Dataset](https://www.yangangwang.com/papers/ZHAO-H3S-2020-02.html) - [FreiHAND Dataset](https://lmb.informatik.uni-freiburg.de/projects/freihand/) - [Surreal Dataset](https://www.di.ens.fr/willow/research/surreal/data/) - [LSP Dataset](http://sam.johnson.io/research/lsp.html) You need to prepare following datasets manually if you want to use them: - [Human3.6M Dataset](http://vision.imar.ro/human3.6m/description.php) and prepare them following [Documentations for Human3.6M Dataset](/common/vision/datasets/keypoint_detection/human36m.py). ## Supported Methods Supported methods include: - [Regressive Domain Adaptation for Unsupervised Keypoint Detection (RegDA, CVPR 2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/regressive-domain-adaptation-cvpr21.pdf) ## Experiment and Results The shell files give the script to reproduce the results with specified hyper-parameters. For example, if you want to train RegDA on RHD->H3D, use the following script ```shell script # Train a RegDA on RHD -> H3D task using PoseResNet. 
# Assume you have put the datasets under the path `data/RHD` and `data/H3D_crop`, # or you are glad to download the datasets automatically from the Internet to this path CUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \ -s RenderedHandPose -t Hand3DStudio --finetune --seed 0 --debug --log logs/regda/rhd2h3d ``` ### RHD->H3D accuracy on ResNet-101 | Methods | MCP | PIP | DIP | Fingertip | Avg | |-------------|------|------|------|-----------|------| | ERM | 67.4 | 64.2 | 63.3 | 54.8 | 61.8 | | RegDA | 79.6 | 74.4 | 71.2 | 62.9 | 72.5 | | Oracle | 97.7 | 97.2 | 95.7 | 92.5 | 95.8 | ### Surreal->Human3.6M accuracy on ResNet-101 | Methods | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Avg | |-------------|----------|-------|-------|------|------|-------|------| | ERM | 69.4 | 75.4 | 66.4 | 37.9 | 77.3 | 77.7 | 67.3 | | RegDA | 73.3 | 86.4 | 72.8 | 54.8 | 82.0 | 84.4 | 75.6 | | Oracle | 95.3 | 91.8 | 86.9 | 95.6 | 94.1 | 93.6 | 92.9 | ### Surreal->LSP accuracy on ResNet-101 | Methods | Shoulder | Elbow | Wrist | Hip | Knee | Ankle | Avg | |-------------|----------|-------|-------|------|------|-------|------| | ERM | 51.5 | 65.0 | 62.9 | 68.0 | 68.7 | 67.4 | 63.9 | | RegDA | 62.7 | 76.7 | 71.1 | 81.0 | 80.3 | 75.3 | 74.6 | | Oracle | 95.3 | 91.8 | 86.9 | 95.6 | 94.1 | 93.6 | 92.9 | ## Visualization If you want to visualize the keypoint detection results during training, you should set --debug. ``` CUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0 ``` Then you can find visualization images in directory ``logs/erm/rhd2h3d/visualize/``. ## TODO Support methods: CycleGAN ## Citation If you use these methods in your research, please consider citing. 
``` @InProceedings{RegDA, author = {Junguang Jiang and Yifei Ji and Ximei Wang and Yufeng Liu and Jianmin Wang and Mingsheng Long}, title = {Regressive Domain Adaptation for Unsupervised Keypoint Detection}, booktitle = {CVPR}, year = {2021} } ``` ================================================ FILE: examples/domain_adaptation/keypoint_detection/erm.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import shutil import torch import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.optim.lr_scheduler import MultiStepLR from torch.utils.data import DataLoader from torchvision.transforms import Compose, ToPILImage sys.path.append('../../..') import tllib.vision.models.keypoint_detection as models from tllib.vision.models.keypoint_detection.loss import JointsMSELoss import tllib.vision.datasets.keypoint_detection as datasets import tllib.vision.transforms.keypoint_detection as T from tllib.vision.transforms import Denormalize from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict from tllib.utils.metric.keypoint_detection import accuracy from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) train_transform = T.Compose([ T.RandomRotation(args.rotation), T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale), T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25), T.GaussianBlur(), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.image_size), T.ToTensor(), normalize ]) image_size = (args.image_size, args.image_size) heatmap_size = (args.heatmap_size, args.heatmap_size) source_dataset = datasets.__dict__[args.source] train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform, image_size=image_size, heatmap_size=heatmap_size) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform, image_size=image_size, heatmap_size=heatmap_size) val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True) target_dataset = datasets.__dict__[args.target] train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform, image_size=image_size, heatmap_size=heatmap_size) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform, image_size=image_size, heatmap_size=heatmap_size) val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True) print("Source train:", len(train_source_loader)) print("Target train:", len(train_target_loader)) print("Source test:", len(val_source_loader)) print("Target test:", 
len(val_target_loader)) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model model = models.__dict__[args.arch](num_keypoints=train_source_dataset.num_keypoints).to(device) criterion = JointsMSELoss() # define optimizer and lr scheduler optimizer = Adam(model.get_parameters(lr=args.lr)) lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor) # optionally resume from a checkpoint start_epoch = 0 if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) start_epoch = checkpoint['epoch'] + 1 # define visualization function tensor_to_image = Compose([ Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ToPILImage() ]) def visualize(image, keypoint2d, name): """ Args: image (tensor): image in shape 3 x H x W keypoint2d (tensor): keypoints in shape K x 2 name: name of the saving image """ train_source_dataset.visualize(tensor_to_image(image), keypoint2d, logger.get_image_path("{}.jpg".format(name))) if args.phase == 'test': # evaluate on validation set source_val_acc = validate(val_source_loader, model, criterion, None, args) target_val_acc = validate(val_target_loader, model, criterion, visualize, args) print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all'])) for name, acc in target_val_acc.items(): print("{}: {:4.3f}".format(name, acc)) return # start training best_acc = 0 for epoch in range(start_epoch, args.epochs): logger.set_epoch(epoch) lr_scheduler.step() # train for one epoch train(train_source_iter, train_target_iter, model, criterion, optimizer, epoch, visualize if args.debug else None, args) # evaluate on validation set source_val_acc = validate(val_source_loader, model, criterion, None, args) target_val_acc = validate(val_target_loader, model, criterion, 
visualize if args.debug else None, args) # remember best acc and save checkpoint torch.save( { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict(), 'epoch': epoch, 'args': args }, logger.get_checkpoint_path(epoch) ) if target_val_acc['all'] > best_acc: shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best')) best_acc = target_val_acc['all'] print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'], target_val_acc['all'], best_acc)) for name, acc in target_val_acc.items(): print("{}: {:4.3f}".format(name, acc)) logger.close() def train(train_source_iter, train_target_iter, model, criterion, optimizer, epoch: int, visualize, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses_s = AverageMeter('Loss (s)', ":.2e") acc_s = AverageMeter("Acc (s)", ":3.2f") progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_s, acc_s], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, label_s, weight_s, meta_s = next(train_source_iter) x_s = x_s.to(device) label_s = label_s.to(device) weight_s = weight_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s = model(x_s) loss_s = criterion(y_s, label_s, weight_s) # compute gradient and do SGD step loss_s.backward() optimizer.step() # measure accuracy and record loss _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(), label_s.detach().cpu().numpy()) acc_s.update(avg_acc_s, cnt_s) losses_s.update(loss_s, cnt_s) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if visualize is not None: visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred.jpg".format(i)) visualize(x_s[0], 
meta_s['keypoint2d'][0], "source_{}_label.jpg".format(i)) def validate(val_loader, model, criterion, visualize, args: argparse.Namespace): batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.2e') acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f") progress = ProgressMeter( len(val_loader), [batch_time, losses, acc['all']], prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (x, label, weight, meta) in enumerate(val_loader): x = x.to(device) label = label.to(device) weight = weight.to(device) # compute output y = model(x) loss = criterion(y, label, weight) # measure accuracy and record loss losses.update(loss.item(), x.size(0)) acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(), label.cpu().numpy()) group_acc = val_loader.dataset.group_accuracy(acc_per_points) acc.update(group_acc, x.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if visualize is not None: visualize(x[0], pred[0] * args.image_size / args.heatmap_size, "val_{}_pred.jpg".format(i)) visualize(x[0], meta['keypoint2d'][0], "val_{}_label.jpg".format(i)) return acc.average() if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='Source Only for Keypoint Detection Domain Adaptation') # dataset parameters parser.add_argument('source_root', help='root path of the source dataset') parser.add_argument('target_root', help='root path of the target dataset') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-scale', nargs='+', 
type=float, default=(0.6, 1.3), help='scale range for the RandomResizeCrop augmentation') parser.add_argument('--rotation', type=int, default=180, help='rotation range of the RandomRotation augmentation') parser.add_argument('--image-size', type=int, default=256, help='input image size') parser.add_argument('--heatmap-size', type=int, default=64, help='output heatmap size') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='pose_resnet101', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: pose_resnet101)') parser.add_argument("--resume", type=str, default=None, help="where restore model parameters from.") # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-step', default=[45, 60], type=tuple, help='parameter for lr scheduler') parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=70, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='src_only', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") parser.add_argument('--debug', action="store_true", help='In the debug mode, save images and predictions') args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/keypoint_detection/erm.sh ================================================ # Source Only # Hands Dataset CUDA_VISIBLE_DEVICES=0 python erm.py data/RHD data/H3D_crop \ -s RenderedHandPose -t Hand3DStudio --log logs/erm/rhd2h3d --debug --seed 0 CUDA_VISIBLE_DEVICES=0 python erm.py data/FreiHand data/RHD \ -s FreiHand -t RenderedHandPose --log logs/erm/freihand2rhd --debug --seed 0 # Body Dataset CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/Human36M \ -s SURREAL -t Human36M --log logs/erm/surreal2human36m --debug --seed 0 --rotation 30 CUDA_VISIBLE_DEVICES=0 python erm.py data/surreal_processed data/lsp \ -s SURREAL -t LSP --log logs/erm/surreal2lsp --debug --seed 0 --rotation 30 # Oracle Results CUDA_VISIBLE_DEVICES=0 python erm.py data/H3D_crop data/H3D_crop \ -s Hand3DStudio -t Hand3DStudio --log logs/oracle/h3d --debug --seed 0 CUDA_VISIBLE_DEVICES=0 python erm.py data/Human36M data/Human36M \ -s Human36M -t Human36M --log logs/oracle/human36m --debug --seed 0 --rotation 30 ================================================ FILE: examples/domain_adaptation/keypoint_detection/regda.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import shutil import torch import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR, MultiStepLR from torch.utils.data import DataLoader from 
torchvision.transforms import Compose, ToPILImage sys.path.append('../../..') from tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \ PseudoLabelGenerator2d, RegressionDisparity import tllib.vision.models as models from tllib.vision.models.keypoint_detection.pose_resnet import Upsampling, PoseResNet from tllib.vision.models.keypoint_detection.loss import JointsKLLoss import tllib.vision.datasets.keypoint_detection as datasets import tllib.vision.transforms.keypoint_detection as T from tllib.vision.transforms import Denormalize from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict from tllib.utils.metric.keypoint_detection import accuracy from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) train_transform = T.Compose([ T.RandomRotation(args.rotation), T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale), T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25), T.GaussianBlur(), T.ToTensor(), normalize ]) val_transform = T.Compose([ T.Resize(args.image_size), T.ToTensor(), normalize ]) image_size = (args.image_size, args.image_size) heatmap_size = (args.heatmap_size, args.heatmap_size) source_dataset = datasets.__dict__[args.source] train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform, image_size=image_size, heatmap_size=heatmap_size) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform, image_size=image_size, heatmap_size=heatmap_size) val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True) target_dataset = datasets.__dict__[args.target] train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform, image_size=image_size, heatmap_size=heatmap_size) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform, image_size=image_size, heatmap_size=heatmap_size) val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True) print("Source train:", len(train_source_loader)) print("Target train:", len(train_target_loader)) print("Source test:", len(val_source_loader)) print("Target test:", 
len(val_target_loader)) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model backbone = models.__dict__[args.arch](pretrained=True) upsampling = Upsampling(backbone.out_features) num_keypoints = train_source_dataset.num_keypoints model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints, num_head_layers=args.num_head_layers, finetune=True).to(device) # define loss function criterion = JointsKLLoss() pseudo_label_generator = PseudoLabelGenerator2d(num_keypoints, args.heatmap_size, args.heatmap_size) regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7)) # define optimizer and lr scheduler optimizer_f = SGD([ {'params': backbone.parameters(), 'lr': 0.1}, {'params': upsampling.parameters(), 'lr': 0.1}, ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True) optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True) optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_decay_function = lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay) lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function) lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function) lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function) start_epoch = 0 if args.resume is None: if args.pretrain is None: # first pretrain the backbone and upsampling print("Pretraining the model on source domain.") args.pretrain = logger.get_checkpoint_path('pretrain') pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device) optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor) best_acc = 0 for epoch in range(args.pretrain_epochs): lr_scheduler.step() print(lr_scheduler.get_lr()) pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args) source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args) # remember best acc and save checkpoint if source_val_acc['all'] > best_acc: best_acc = source_val_acc['all'] torch.save( { 'model': pretrained_model.state_dict() }, args.pretrain ) print("Source: {} best: {}".format(source_val_acc['all'], best_acc)) # load from the pretrained checkpoint pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model'] model_dict = model.state_dict() # remove keys from pretrained dict that doesn't appear in model dict pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model.load_state_dict(pretrained_dict, strict=False) else: # optionally resume from a checkpoint checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model']) optimizer_f.load_state_dict(checkpoint['optimizer_f']) optimizer_h.load_state_dict(checkpoint['optimizer_h']) optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv']) lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f']) 
lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h']) lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv']) start_epoch = checkpoint['epoch'] + 1 # define visualization function tensor_to_image = Compose([ Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ToPILImage() ]) def visualize(image, keypoint2d, name, heatmaps=None): """ Args: image (tensor): image in shape 3 x H x W keypoint2d (tensor): keypoints in shape K x 2 name: name of the saving image """ train_source_dataset.visualize(tensor_to_image(image), keypoint2d, logger.get_image_path("{}.jpg".format(name))) if args.phase == 'test': # evaluate on validation set source_val_acc = validate(val_source_loader, model, criterion, None, args) target_val_acc = validate(val_target_loader, model, criterion, visualize, args) print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all'])) for name, acc in target_val_acc.items(): print("{}: {:4.3f}".format(name, acc)) return # start training best_acc = 0 print("Start regression domain adaptation.") for epoch in range(start_epoch, args.epochs): logger.set_epoch(epoch) print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr()) # train for one epoch train(train_source_iter, train_target_iter, model, criterion, regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch, visualize if args.debug else None, args) # evaluate on validation set source_val_acc = validate(val_source_loader, model, criterion, None, args) target_val_acc = validate(val_target_loader, model, criterion, visualize if args.debug else None, args) # remember best acc and save checkpoint torch.save( { 'model': model.state_dict(), 'optimizer_f': optimizer_f.state_dict(), 'optimizer_h': optimizer_h.state_dict(), 'optimizer_h_adv': optimizer_h_adv.state_dict(), 'lr_scheduler_f': lr_scheduler_f.state_dict(), 'lr_scheduler_h': lr_scheduler_h.state_dict(), 
'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(), 'epoch': epoch, 'args': args }, logger.get_checkpoint_path(epoch) ) if target_val_acc['all'] > best_acc: shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best')) best_acc = target_val_acc['all'] print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'], target_val_acc['all'], best_acc)) for name, acc in target_val_acc.items(): print("{}: {:4.3f}".format(name, acc)) logger.close() def pretrain(train_source_iter, model, criterion, optimizer, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses_s = AverageMeter('Loss (s)', ":.2e") acc_s = AverageMeter("Acc (s)", ":3.2f") progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_s, acc_s], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): optimizer.zero_grad() x_s, label_s, weight_s, meta_s = next(train_source_iter) x_s = x_s.to(device) label_s = label_s.to(device) weight_s = weight_s.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s = model(x_s) loss_s = criterion(y_s, label_s, weight_s) # compute gradient and do SGD step loss_s.backward() optimizer.step() # measure accuracy and record loss _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(), label_s.detach().cpu().numpy()) acc_s.update(avg_acc_s, cnt_s) losses_s.update(loss_s, cnt_s) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) def train(train_source_iter, train_target_iter, model, criterion,regression_disparity, optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h, lr_scheduler_h_adv, epoch: int, visualize, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') 
losses_s = AverageMeter('Loss (s)', ":.2e") losses_gf = AverageMeter('Loss (t, false)', ":.2e") losses_gt = AverageMeter('Loss (t, truth)', ":.2e") acc_s = AverageMeter("Acc (s)", ":3.2f") acc_t = AverageMeter("Acc (t)", ":3.2f") acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f") acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f") progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_s, losses_gf, losses_gt, acc_s, acc_t, acc_s_adv, acc_t_adv], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x_s, label_s, weight_s, meta_s = next(train_source_iter) x_t, label_t, weight_t, meta_t = next(train_target_iter) x_s = x_s.to(device) label_s = label_s.to(device) weight_s = weight_s.to(device) x_t = x_t.to(device) label_t = label_t.to(device) weight_t = weight_t.to(device) # measure data loading time data_time.update(time.time() - end) # Step A train all networks to minimize loss on source domain optimizer_f.zero_grad() optimizer_h.zero_grad() optimizer_h_adv.zero_grad() y_s, y_s_adv = model(x_s) loss_s = criterion(y_s, label_s, weight_s) + \ args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min') loss_s.backward() optimizer_f.step() optimizer_h.step() optimizer_h_adv.step() # Step B train adv regressor to maximize regression disparity optimizer_h_adv.zero_grad() y_t, y_t_adv = model(x_t) loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max') loss_ground_false.backward() optimizer_h_adv.step() # Step C train feature extractor to minimize regression disparity optimizer_f.zero_grad() y_t, y_t_adv = model(x_t) loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min') loss_ground_truth.backward() optimizer_f.step() # do update step model.step() lr_scheduler_f.step() lr_scheduler_h.step() lr_scheduler_h_adv.step() # measure accuracy and record loss _, avg_acc_s, cnt_s, 
pred_s = accuracy(y_s.detach().cpu().numpy(), label_s.detach().cpu().numpy()) acc_s.update(avg_acc_s, cnt_s) _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(), label_t.detach().cpu().numpy()) acc_t.update(avg_acc_t, cnt_t) _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(), label_s.detach().cpu().numpy()) acc_s_adv.update(avg_acc_s_adv, cnt_s) _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(), label_t.detach().cpu().numpy()) acc_t_adv.update(avg_acc_t_adv, cnt_t) losses_s.update(loss_s, cnt_s) losses_gf.update(loss_ground_false, cnt_s) losses_gt.update(loss_ground_truth, cnt_s) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if visualize is not None: visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred".format(i)) visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i)) visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, "target_{}_pred".format(i)) visualize(x_t[0], meta_t['keypoint2d'][0], "target_{}_label".format(i)) visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, "source_adv_{}_pred".format(i)) visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, "target_adv_{}_pred".format(i)) def validate(val_loader, model, criterion, visualize, args: argparse.Namespace): batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.2e') acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f") progress = ProgressMeter( len(val_loader), [batch_time, losses, acc['all']], prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (x, label, weight, meta) in enumerate(val_loader): x = x.to(device) label = label.to(device) weight = weight.to(device) # compute output y = model(x) loss = criterion(y, label, weight) # measure accuracy and record loss 
losses.update(loss.item(), x.size(0)) acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(), label.cpu().numpy()) group_acc = val_loader.dataset.group_accuracy(acc_per_points) acc.update(group_acc, x.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if visualize is not None: visualize(x[0], pred[0] * args.image_size / args.heatmap_size, "val_{}_pred.jpg".format(i)) visualize(x[0], meta['keypoint2d'][0], "val_{}_label.jpg".format(i)) return acc.average() if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description='RegDA for Keypoint Detection Domain Adaptation') # dataset parameters parser.add_argument('source_root', help='root path of the source dataset') parser.add_argument('target_root', help='root path of the target dataset') parser.add_argument('-s', '--source', help='source domain(s)') parser.add_argument('-t', '--target', help='target domain(s)') parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.6, 1.3), help='scale range for the RandomResizeCrop augmentation') parser.add_argument('--rotation', type=int, default=180, help='rotation range of the RandomRotation augmentation') parser.add_argument('--image-size', type=int, default=256, help='input image size') parser.add_argument('--heatmap-size', type=int, default=64, help='output heatmap size') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet101', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet101)') parser.add_argument("--pretrain", type=str, default=None, help="Where restore pretrained model parameters from.") 
parser.add_argument("--resume", type=str, default=None, help="where restore model parameters from.") parser.add_argument('--num-head-layers', type=int, default=2) parser.add_argument('--margin', type=float, default=4., help="margin gamma") parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--lr-gamma', default=0.0001, type=float) parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-step', default=[45, 60], type=tuple, help='parameter for lr scheduler') parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--epochs', default=30, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='regda', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") parser.add_argument('--debug', action="store_true", help='In the debug mode, save images and predictions') args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/keypoint_detection/regda.sh ================================================ # Hands Dataset CUDA_VISIBLE_DEVICES=0 python regda.py data/RHD data/H3D_crop \ -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda/rhd2h3d CUDA_VISIBLE_DEVICES=0 python regda.py data/FreiHand data/RHD \ -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda/freihand2rhd # Body Dataset CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/Human36M \ -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda/surreal2human36m CUDA_VISIBLE_DEVICES=0 python regda.py data/surreal_processed data/lsp \ -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda/surreal2lsp ================================================ FILE: examples/domain_adaptation/keypoint_detection/regda_fast.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import shutil import torch import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR, MultiStepLR from torch.utils.data import DataLoader from torchvision.transforms import Compose, ToPILImage sys.path.append('../../..') from tllib.alignment.regda import PoseResNet2d as RegDAPoseResNet, \ FastPseudoLabelGenerator2d, RegressionDisparity import tllib.vision.models as models from tllib.vision.models.keypoint_detection.pose_resnet 
import Upsampling, PoseResNet
from tllib.vision.models.keypoint_detection.loss import JointsKLLoss
import tllib.vision.datasets.keypoint_detection as datasets
import tllib.vision.transforms.keypoint_detection as T
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter, AverageMeterDict
from tllib.utils.metric.keypoint_detection import accuracy
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Entry point for RegDA (fast variant) keypoint-detection domain adaptation.

    Pipeline:
      1. build source/target train and test loaders;
      2. pretrain the backbone + upsampling head on the source domain
         (unless ``--pretrain`` or ``--resume`` supplies a checkpoint);
      3. run adversarial adaptation (see ``train``), checkpointing every epoch
         and keeping the target-best model under ``checkpoints/best``.

    With ``--phase test`` only evaluation is performed.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # heavy augmentation for training; plain resize for evaluation
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(),
        normalize
    ])
    val_transform = T.Compose([
        T.Resize(args.image_size),
        T.ToTensor(),
        normalize
    ])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform,
                                          image_size=image_size, heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root, split='test', transforms=val_transform,
                                        image_size=image_size, heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    # wrap the train loaders so `next()` never raises StopIteration
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    backbone = models.__dict__[args.arch](pretrained=True)
    upsampling = Upsampling(backbone.out_features)
    num_keypoints = train_source_dataset.num_keypoints
    model = RegDAPoseResNet(backbone, upsampling, 256, num_keypoints,
                            num_head_layers=args.num_head_layers, finetune=True).to(device)
    # define loss function
    criterion = JointsKLLoss()
    pseudo_label_generator = FastPseudoLabelGenerator2d()
    regression_disparity = RegressionDisparity(pseudo_label_generator, JointsKLLoss(epsilon=1e-7))

    # define optimizer and lr scheduler
    # feature extractor (backbone + upsampling) trains at a smaller lr than the
    # two regression heads, which use lr=1 scaled by the LambdaLR below
    optimizer_f = SGD([
        {'params': backbone.parameters(), 'lr': 0.1},
        {'params': upsampling.parameters(), 'lr': 0.1},
    ], lr=0.1, momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h = SGD(model.head.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    optimizer_h_adv = SGD(model.head_adv.parameters(), lr=1., momentum=args.momentum, weight_decay=args.wd,
                          nesterov=True)
    # inverse-decay schedule, stepped once per iteration inside `train`
    lr_decay_function = lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)
    lr_scheduler_f = LambdaLR(optimizer_f, lr_decay_function)
    lr_scheduler_h = LambdaLR(optimizer_h, lr_decay_function)
    lr_scheduler_h_adv = LambdaLR(optimizer_h_adv, lr_decay_function)
    start_epoch = 0

    if args.resume is None:
        if args.pretrain is None:
            # first pretrain the backbone and upsampling
            print("Pretraining the model on source domain.")
            args.pretrain = logger.get_checkpoint_path('pretrain')
            pretrained_model = PoseResNet(backbone, upsampling, 256, num_keypoints, True).to(device)
            optimizer = SGD(pretrained_model.get_parameters(lr=args.lr), momentum=args.momentum,
                            weight_decay=args.wd, nesterov=True)
            lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)
            best_acc = 0
            for epoch in range(args.pretrain_epochs):
                # NOTE(review): scheduler is stepped *before* the epoch's
                # optimizer steps; modern PyTorch recommends stepping after —
                # this shifts each milestone one epoch earlier. Confirm intent.
                lr_scheduler.step()
                print(lr_scheduler.get_lr())
                pretrain(train_source_iter, pretrained_model, criterion, optimizer, epoch, args)
                source_val_acc = validate(val_source_loader, pretrained_model, criterion, None, args)
                # remember best acc and save checkpoint
                if source_val_acc['all'] > best_acc:
                    best_acc = source_val_acc['all']
                    torch.save(
                        {
                            'model': pretrained_model.state_dict()
                        }, args.pretrain
                    )
                print("Source: {} best: {}".format(source_val_acc['all'], best_acc))
        # load from the pretrained checkpoint
        pretrained_dict = torch.load(args.pretrain, map_location='cpu')['model']
        model_dict = model.state_dict()
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
    else:
        # optionally resume from a checkpoint
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer_f.load_state_dict(checkpoint['optimizer_f'])
        optimizer_h.load_state_dict(checkpoint['optimizer_h'])
        optimizer_h_adv.load_state_dict(checkpoint['optimizer_h_adv'])
        lr_scheduler_f.load_state_dict(checkpoint['lr_scheduler_f'])
        lr_scheduler_h.load_state_dict(checkpoint['lr_scheduler_h'])
        lr_scheduler_h_adv.load_state_dict(checkpoint['lr_scheduler_h_adv'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name, heatmaps=None):
        """Save a debug image with keypoints drawn on it.

        Args:
            image (tensor): normalized image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: base name of the saved image; ".jpg" is appended here
            heatmaps: unused; kept for call-site compatibility
        """
        train_source_dataset.visualize(tensor_to_image(image),
                                       keypoint2d, logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion, visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'], target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    print("Start regression domain adaptation.")
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler_f.get_lr(), lr_scheduler_h.get_lr(), lr_scheduler_h_adv.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
              optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h,
              lr_scheduler_h_adv, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None, args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer_f': optimizer_f.state_dict(),
                'optimizer_h': optimizer_h.state_dict(),
                'optimizer_h_adv': optimizer_h_adv.state_dict(),
                'lr_scheduler_f': lr_scheduler_f.state_dict(),
                'lr_scheduler_h': lr_scheduler_h.state_dict(),
                'lr_scheduler_h_adv': lr_scheduler_h_adv.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        # the target-best checkpoint is kept as 'best'
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(source_val_acc['all'],
                                                                             target_val_acc['all'],
                                                                             best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
def pretrain(train_source_iter, model, criterion, optimizer, epoch: int, args: argparse.Namespace):
    """Pretrain the pose model on the source domain for one epoch.

    Args:
        train_source_iter: infinite iterator over the source training set,
            yielding (image, heatmap label, keypoint weight, meta) batches.
        model: pose network mapping images to heatmaps.
        criterion: heatmap loss taking (output, label, weight).
        optimizer: optimizer over all model parameters.
        epoch (int): current epoch index (used only for log prefixes).
        args (argparse.Namespace): reads ``iters_per_epoch`` and ``print_freq``.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, acc_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s)

        # compute gradient and do SGD step
        loss_s.backward()
        optimizer.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        # FIX: record a Python float, not the loss tensor — storing the tensor
        # in the meter keeps the iteration's autograd graph alive and leaks
        # GPU memory over the epoch.
        losses_s.update(loss_s.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
def train(train_source_iter, train_target_iter, model, criterion, regression_disparity,
          optimizer_f, optimizer_h, optimizer_h_adv, lr_scheduler_f, lr_scheduler_h,
          lr_scheduler_h_adv, epoch: int, visualize, args: argparse.Namespace):
    """Run one epoch of RegDA adversarial adaptation.

    Each iteration performs the three RegDA steps:
      A. minimize the supervised source loss plus a margin-weighted disparity
         term w.r.t. the feature extractor and both regression heads;
      B. maximize the regression disparity on target data w.r.t. the
         adversarial head only;
      C. minimize the regression disparity on target data w.r.t. the feature
         extractor only.

    Args:
        train_source_iter / train_target_iter: infinite iterators yielding
            (image, heatmap label, keypoint weight, meta) batches.
        model: RegDA pose network returning (main head output, adversarial head output).
        criterion: supervised heatmap loss.
        regression_disparity: disparity loss; ``mode`` selects the direction.
        optimizer_f / optimizer_h / optimizer_h_adv: optimizers for the feature
            extractor, main head and adversarial head respectively.
        lr_scheduler_*: schedulers stepped once per iteration.
        epoch (int): current epoch (for log prefixes).
        visualize: optional callable(image, keypoints, name) for debug dumps.
        args: reads margin, trade_off, iters_per_epoch, print_freq,
            image_size, heatmap_size.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ":.2e")
    losses_gf = AverageMeter('Loss (t, false)', ":.2e")
    losses_gt = AverageMeter('Loss (t, truth)', ":.2e")
    acc_s = AverageMeter("Acc (s)", ":3.2f")
    acc_t = AverageMeter("Acc (t)", ":3.2f")
    acc_s_adv = AverageMeter("Acc (s, adv)", ":3.2f")
    acc_t_adv = AverageMeter("Acc (t, adv)", ":3.2f")
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_gf, losses_gt,
         acc_s, acc_t, acc_s_adv, acc_t_adv],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s, weight_s, meta_s = next(train_source_iter)
        x_t, label_t, weight_t, meta_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.to(device)
        weight_s = weight_s.to(device)
        x_t = x_t.to(device)
        label_t = label_t.to(device)
        weight_t = weight_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Step A train all networks to minimize loss on source domain
        optimizer_f.zero_grad()
        optimizer_h.zero_grad()
        optimizer_h_adv.zero_grad()
        y_s, y_s_adv = model(x_s)
        loss_s = criterion(y_s, label_s, weight_s) + \
                 args.margin * args.trade_off * regression_disparity(y_s, y_s_adv, weight_s, mode='min')
        loss_s.backward()
        optimizer_f.step()
        optimizer_h.step()
        optimizer_h_adv.step()

        # Step B train adv regressor to maximize regression disparity
        optimizer_h_adv.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_false = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='max')
        loss_ground_false.backward()
        optimizer_h_adv.step()

        # Step C train feature extractor to minimize regression disparity
        optimizer_f.zero_grad()
        y_t, y_t_adv = model(x_t)
        loss_ground_truth = args.trade_off * regression_disparity(y_t, y_t_adv, weight_t, mode='min')
        loss_ground_truth.backward()
        optimizer_f.step()

        # do update step (model-specific per-iteration hook; see RegDAPoseResNet)
        model.step()
        lr_scheduler_f.step()
        lr_scheduler_h.step()
        lr_scheduler_h_adv.step()

        # measure accuracy and record loss
        _, avg_acc_s, cnt_s, pred_s = accuracy(y_s.detach().cpu().numpy(),
                                               label_s.detach().cpu().numpy())
        acc_s.update(avg_acc_s, cnt_s)
        _, avg_acc_t, cnt_t, pred_t = accuracy(y_t.detach().cpu().numpy(),
                                               label_t.detach().cpu().numpy())
        acc_t.update(avg_acc_t, cnt_t)
        _, avg_acc_s_adv, cnt_s_adv, pred_s_adv = accuracy(y_s_adv.detach().cpu().numpy(),
                                                           label_s.detach().cpu().numpy())
        acc_s_adv.update(avg_acc_s_adv, cnt_s)
        _, avg_acc_t_adv, cnt_t_adv, pred_t_adv = accuracy(y_t_adv.detach().cpu().numpy(),
                                                           label_t.detach().cpu().numpy())
        acc_t_adv.update(avg_acc_t_adv, cnt_t)
        # FIX: record Python floats rather than loss tensors — storing the
        # tensors keeps each iteration's autograd graph alive in the meters
        # and leaks GPU memory over the epoch.
        losses_s.update(loss_s.item(), cnt_s)
        losses_gf.update(loss_ground_false.item(), cnt_s)
        losses_gt.update(loss_ground_truth.item(), cnt_s)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                # map heatmap-space predictions back to image-space coordinates
                visualize(x_s[0], pred_s[0] * args.image_size / args.heatmap_size, "source_{}_pred".format(i))
                visualize(x_s[0], meta_s['keypoint2d'][0], "source_{}_label".format(i))
                visualize(x_t[0], pred_t[0] * args.image_size / args.heatmap_size, "target_{}_pred".format(i))
                visualize(x_t[0], meta_t['keypoint2d'][0], "target_{}_label".format(i))
                visualize(x_s[0], pred_s_adv[0] * args.image_size / args.heatmap_size, "source_adv_{}_pred".format(i))
                visualize(x_t[0], pred_t_adv[0] * args.image_size / args.heatmap_size, "target_adv_{}_pred".format(i))
def validate(val_loader, model, criterion, visualize, args: argparse.Namespace):
    """Evaluate ``model`` on ``val_loader``.

    Args:
        val_loader: test-split DataLoader yielding (image, label, weight, meta).
        model: pose network; in eval mode it returns a single heatmap tensor.
        criterion: heatmap loss taking (output, label, weight).
        visualize: optional callable(image, keypoints, name) for debug dumps.
        args: reads ``print_freq``, ``image_size`` and ``heatmap_size``.

    Returns:
        dict mapping keypoint-group name (including 'all') to average accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.2e')
    acc = AverageMeterDict(val_loader.dataset.keypoints_group.keys(), ":3.2f")
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc['all']],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (x, label, weight, meta) in enumerate(val_loader):
            x = x.to(device)
            label = label.to(device)
            weight = weight.to(device)

            # compute output
            y = model(x)
            loss = criterion(y, label, weight)

            # measure accuracy and record loss
            losses.update(loss.item(), x.size(0))
            acc_per_points, avg_acc, cnt, pred = accuracy(y.cpu().numpy(),
                                                          label.cpu().numpy())
            group_acc = val_loader.dataset.group_accuracy(acc_per_points)
            acc.update(group_acc, x.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                if visualize is not None:
                    # FIX: pass bare names — the `visualize` helper defined in
                    # main() appends ".jpg" itself, so the previous
                    # "val_{}_pred.jpg" names produced "*.jpg.jpg" files.
                    # Now consistent with the names used in `train`.
                    visualize(x[0], pred[0] * args.image_size / args.heatmap_size, "val_{}_pred".format(i))
                    visualize(x[0], meta['keypoint2d'][0], "val_{}_label".format(i))

    return acc.average()
parser.add_argument("--pretrain", type=str, default=None,
                    help="Where restore pretrained model parameters from.")
parser.add_argument("--resume", type=str, default=None,
                    help="where restore model parameters from.")
parser.add_argument('--num-head-layers', type=int, default=2)
parser.add_argument('--margin', type=float, default=4., help="margin gamma")
parser.add_argument('--trade-off', default=1., type=float,
                    help='the trade-off hyper-parameter for transfer loss')
# training parameters
parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                    help='mini-batch size (default: 32)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                    help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
parser.add_argument('--wd', '--weight-decay', default=0.0001, type=float, metavar='W',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--lr-gamma', default=0.0001, type=float)
parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
# FIX: was `type=tuple`, which converts a command-line value such as "45,60"
# into a tuple of single characters and breaks MultiStepLR. `nargs='+'` with
# `type=int` accepts `--lr-step 45 60`; the default milestones are unchanged.
parser.add_argument('--lr-step', default=[45, 60], type=int, nargs='+',
                    help='parameter for lr scheduler')
parser.add_argument('--lr-factor', default=0.1, type=float, help='parameter for lr scheduler')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--pretrain_epochs', default=70, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--epochs', default=30, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                    help='Number of iterations per epoch')
parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                    help='print frequency (default: 100)')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
') parser.add_argument("--log", type=str, default='regda_fast', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") parser.add_argument('--debug', action="store_true", help='In the debug mode, save images and predictions') args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/keypoint_detection/regda_fast.sh ================================================ # regda_fast is provided by https://github.com/YouJiacheng?tab=repositories # On single V100(16G), overall adversarial training time is reduced by about 40%. # yet the PCK might drop 1% for each dataset. # Hands Dataset CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/RHD data/H3D_crop \ -s RenderedHandPose -t Hand3DStudio --seed 0 --debug --log logs/regda_fast/rhd2h3d CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/FreiHand data/RHD \ -s FreiHand -t RenderedHandPose --seed 0 --debug --log logs/regda_fast/freihand2rhd # Body Dataset CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/Human36M \ -s SURREAL -t Human36M --seed 0 --debug --rotation 30 --epochs 10 --log logs/regda_fast/surreal2human36m CUDA_VISIBLE_DEVICES=0 python regda_fast.py data/surreal_processed data/lsp \ -s SURREAL -t LSP --seed 0 --debug --rotation 30 --log logs/regda_fast/surreal2lsp ================================================ FILE: examples/domain_adaptation/object_detection/README.md ================================================ # Unsupervised Domain Adaptation for Object Detection ## Updates - *04/2022*: Provide CycleGAN translated datasets. ## Installation Our code is based on [Detectron latest(v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html), please install it before usage. The following is an example based on PyTorch 1.9.0 with CUDA 11.1. 
For other versions, please refer to the official website of [PyTorch](https://pytorch.org/) and [Detectron](https://detectron2.readthedocs.io/en/latest/tutorials/install.html). ```shell # create environment conda create -n detection python=3.8.3 # activate environment conda activate detection # install pytorch pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html # install detectron python -m pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html # install other requirements pip install -r requirements.txt ``` ## Dataset Following datasets can be downloaded automatically: - [PASCAL_VOC 07+12](http://host.robots.ox.ac.uk/pascal/VOC/) - Clipart - WaterColor - Comic You need to prepare following datasets manually if you want to use them: #### Cityscapes, Foggy Cityscapes - Download Cityscapes and Foggy Cityscapes dataset from the [link](https://www.cityscapes-dataset.com/downloads/). Particularly, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes. - Unzip them under the directory like ``` object_detction/datasets/cityscapes ├── gtFine ├── leftImg8bit ├── leftImg8bit_foggy └── ... ``` Then run ``` python prepare_cityscapes_to_voc.py ``` This will automatically generate dataset in `VOC` format. ``` object_detction/datasets/cityscapes_in_voc ├── Annotations ├── ImageSets └── JPEGImages object_detction/datasets/foggy_cityscapes_in_voc ├── Annotations ├── ImageSets └── JPEGImages ``` #### Sim10k - Download Sim10k dataset from the following links: [Sim10k](https://fcav.engin.umich.edu/projects/driving-in-the-matrix). Particularly, we use *repro_10k_images.tgz* , *repro_image_sets.tgz* and *repro_10k_annotations.tgz* for Sim10k. - Extract the training set from *repro_10k_images.tgz*, *repro_image_sets.tgz* and *repro_10k_annotations.tgz*, then rename directory `VOC2012/` to `sim10k/`. 
After preparation, the following files should exist:

```
object_detection/datasets/
├── VOC2007
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── VOC2012
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── clipart
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── watercolor
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── comic
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── cityscapes_in_voc
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
├── foggy_cityscapes_in_voc
│   ├── Annotations
│   ├── ImageSets
│   └── JPEGImages
└── sim10k
    ├── Annotations
    ├── ImageSets
    └── JPEGImages
```

**Note**: The above is a tutorial for using standard datasets. To use your own datasets, you need to convert them into the corresponding format.

#### CycleGAN translated dataset

The following command uses CycleGAN to translate VOC (with directory `datasets/VOC2007` and `datasets/VOC2012`) to Clipart (with directory `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`).

```
mkdir datasets/VOC2007_to_clipart
cp -r datasets/VOC2007/* datasets/VOC2007_to_clipart
mkdir datasets/VOC2012_to_clipart
cp -r datasets/VOC2012/* datasets/VOC2012_to_clipart
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \
  --translated-source datasets/VOC2007_to_clipart datasets/VOC2012_to_clipart \
  --log logs/cyclegan_resnet9/translation/voc2clipart --netG resnet_9
```

You can also download and use datasets that are translated by us.
- PASCAL_VOC to Clipart [[07]](https://cloud.tsinghua.edu.cn/f/1b6b060d202145aea416/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/818dbd8e41a043fab7c3/?dl=1) (with directory `datasets/VOC2007_to_clipart` and `datasets/VOC2012_to_clipart`) - PASCAL_VOC to Comic [[07]](https://cloud.tsinghua.edu.cn/f/89382bba64514210a9f8/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/f90289137fd5465f806d/?dl=1) (with directory `datasets/VOC2007_to_comic` and `datasets/VOC2012_to_comic`) - PASCAL_VOC to WaterColor [[07]](https://cloud.tsinghua.edu.cn/f/8e982e9f21294b38be8a/?dl=1)+[[12]](https://cloud.tsinghua.edu.cn/f/b8235034cb4247ce809f/?dl=1) (with directory `datasets/VOC2007_to_watercolor` and `datasets/VOC2012_to_watercolor`) - Cityscapes to Foggy Cityscapes [[Part1]](https://cloud.tsinghua.edu.cn/f/09ceeb25a476481bae29/?dl=1) [[Part2]](https://cloud.tsinghua.edu.cn/f/51fb05d3ee614e7d87a0/?dl=1) [[Part3]](https://cloud.tsinghua.edu.cn/f/646415daf6b344c3a9e3/?dl=1) [[Part4]](https://cloud.tsinghua.edu.cn/f/008d5d3c54344f83b101/?dl=1) (with directory `datasets/cityscapes_to_foggy_cityscapes`). Note that you need to use ``cat`` to merge the downloaded files. - Sim10k to Cityscapes (Car) [[Download]](https://cloud.tsinghua.edu.cn/f/33ac656fcde34f758dcd/?dl=1) (with directory `datasets/sim10k2cityscapes_car`). ## Supported Methods Supported methods include: - [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf) - [Decoupled Adaptation for Cross-Domain Object Detection (D-adapt)](https://arxiv.org/abs/2110.02578) ## Experiment and Results The shell files give the script to reproduce the [benchmarks](/docs/dalib/benchmarks/object_detection.rst) with specified hyper-parameters. The basic training pipeline is as follows. The following command trains a Faster-RCNN detector on task VOC->Clipart, with only source (VOC) data. 
``` CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \ --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart ``` Explanation of some arguments - `--config-file`: path to config file that specifies training hyper-parameters. - `-s`: a list that specifies source datasets, for each dataset you should pass in a `(name, path)` pair, in the above command, there are two source datasets **VOC2007** and **VOC2012**. - `-t`: a list that specifies target datasets, same format as above. - `--test`: a list that specifiers test datasets, same format as above. ### VOC->Clipart | | | AP | AP50 | AP75 | aeroplane | bicycle | bird | boat | bottle | bus | car | cat | chair | cow | diningtable | dog | horse | motorbike | person | pottedplant | sheep | sofa | train | tvmonitor | |-------------------------|----------|------|------|------|-----------|---------|------|------|--------|------|------|------|-------|------|-------------|------|-------|-----------|--------|-------------|-------|------|-------|-----------| | Faster RCNN (ResNet101) | Source | 14.9 | 29.3 | 12.6 | 29.6 | 38.0 | 24.7 | 21.7 | 31.9 | 48.0 | 30.8 | 15.9 | 32.0 | 19.2 | 18.2 | 12.1 | 28.2 | 48.8 | 38.3 | 34.6 | 3.8 | 22.5 | 43.7 | 44.0 | | | CycleGAN | 20.0 | 37.7 | 18.3 | 37.1 | 41.9 | 29.9 | 26.5 | 40.9 | 65.1 | 37.8 | 23.8 | 40.7 | 48.9 | 12.7 | 14.4 | 27.8 | 63.0 | 55.1 | 40.1 | 8.0 | 30.7 | 54.1 | 55.7 | | | D-adapt | 24.8 | 49.0 | 21.5 | 56.4 | 63.2 | 42.3 | 40.9 | 45.3 | 77.0 | 48.7 | 25.4 | 44.3 | 58.4 | 31.4 | 24.5 | 47.1 | 75.3 | 69.3 | 43.5 | 27.9 | 34.1 | 60.7 | 64.0 | | | | | | | | | | | | | | | | | | | | | | | | | | | | RetinaNet | Source | 18.3 | 32.2 | 17.6 | 34.2 | 42.4 | 27.0 | 21.6 | 36.8 | 48.4 | 35.9 | 16.4 | 38.9 | 22.6 | 27.0 | 15.1 | 27.1 | 46.7 | 42.1 | 36.2 | 8.3 | 29.5 | 42.1 | 
46.2 | | | D-adapt | 25.1 | 46.3 | 23.9 | 47.4 | 65.0 | 33.1 | 37.5 | 56.8 | 61.2 | 55.1 | 27.3 | 45.5 | 51.8 | 29.1 | 29.6 | 38.0 | 74.5 | 66.7 | 46.0 | 24.2 | 29.3 | 54.2 | 53.8 | ### VOC->WaterColor | | AP | AP50 | AP75 | bicycle | bird | car | cat | dog | person | |-------------------------|------|------|------|---------|------|------|------|------|--------| | Faster RCNN (ResNet101) | 23.0 | 45.9 | 18.5 | 71.1 | 48.3 | 48.6 | 23.7 | 23.3 | 60.3 | | CycleGAN | 24.9 | 50.8 | 22.4 | 75.8 | 52.1 | 49.8 | 30.1 | 33.4 | 63.6 | | D-adapt | 28.5 | 57.5 | 23.6 | 77.4 | 54.0 | 52.8 | 43.9 | 48.1 | 68.9 | | Target | 23.8 | 51.3 | 17.4 | 48.5 | 54.7 | 41.3 | 36.2 | 52.6 | 74.6 | ### VOC->Comic | | AP | AP50 | AP75 | bicycle | bird | car | cat | dog | person | |:-----------------------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----:|:----:|:------:| | Faster RCNN (ResNet101) | 13.0 | 25.5 | 11.4 | 33.0 | 15.8 | 28.9 | 16.8 | 19.6 | 39.0 | | CycleGAN | 16.9 | 34.6 | 14.2 | 28.1 | 25.7 | 37.7 | 28.0 | 33.8 | 54.1 | | D-adapt | 20.8 | 41.1 | 18.5 | 49.4 | 25.7 | 43.3 | 36.9 | 32.7 | 58.5 | | Target | 21.9 | 44.6 | 16.0 | 40.7 | 32.3 | 38.3 | 43.9 | 41.3 | 71.0 | ### Cityscapes->Foggy Cityscapes | | | AP | AP50 | AP75 | bicycle | bus | car | motorcycle | person | rider | train | truck | |:-----------------------:|:--------:|:----:|:----:|:----:|:-------:|:----:|:----:|:----------:|:------:|:-----:|:-----:|:-----:| | Faster RCNN (VGG16) | Source | 14.3 | 25.9 | 13.2 | 33.6 | 27.0 | 40.0 | 22.3 | 31.3 | 38.5 | 2.3 | 12.2 | | | CycleGAN | 22.5 | 41.6 | 20.7 | 46.5 | 41.5 | 62.0 | 33.8 | 45.0 | 54.5 | 21.7 | 27.7 | | | D-adapt | 19.4 | 38.1 | 17.5 | 42.0 | 36.8 | 58.1 | 32.2 | 43.1 | 51.8 | 14.6 | 26.3 | | | Target | 24.0 | 45.3 | 21.3 | 45.9 | 47.4 | 67.3 | 39.7 | 49.0 | 53.2 | 30.0 | 29.6 | | | | | | | | | | | | | | | | Faster RCNN (ResNet101) | Source | 18.8 | 33.3 | 19.0 | 36.1 | 34.5 | 43.8 | 24.0 | 36.3 | 39.9 | 29.1 | 22.8 | | | CycleGAN | 22.9 | 41.8 | 21.9 | 42.0 | 
44.5 | 57.6 | 36.3 | 40.9 | 48.0 | 30.8 | 34.3 | | | D-adapt | 22.7 | 42.4 | 21.6 | 41.8 | 44.4 | 56.6 | 31.4 | 41.8 | 48.6 | 42.3 | 32.4 | | | Target | 25.5 | 45.3 | 24.3 | 41.9 | 53.2 | 63.4 | 36.1 | 42.6 | 47.9 | 42.4 | 35.3 | ### Sim10k->Cityscapes Car | | | AP | AP50 | AP75 | |:-----------------------:|:--------:|:----:|:----:|:----:| | Faster RCNN (VGG16) | Source | 24.8 | 43.4 | 23.6 | | | CycleGAN | 29.3 | 51.9 | 28.6 | | | D-adapt | 23.6 | 48.5 | 18.7 | | | Target | 24.8 | 43.4 | 23.6 | | | | | | | | Faster RCNN (ResNet101) | Source | 24.6 | 44.4 | 23.0 | | | CycleGAN | 26.5 | 47.4 | 24.0 | | | D-adapt | 27.4 | 51.9 | 25.7 | | | Target | 24.6 | 44.4 | 23.0 | ### Visualization We provide code for visualization in `visualize.py`. For example, suppose you have trained the source only model of task VOC->Clipart using provided scripts. The following code visualizes the prediction of the detector on Clipart. ```shell CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \ --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \ MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth ``` Explanation of some arguments - `--test`: a list that specifiers test datasets for visualization. - `--save-path`: where to save visualization results. - `MODEL.WEIGHTS`: path to the model. ## TODO Support methods: SWDA, Global/Local Alignment ## Citation If you use these methods in your research, please consider citing. 
``` @inproceedings{jiang2021decoupled, title = {Decoupled Adaptation for Cross-Domain Object Detection}, author = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long}, booktitle = {ICLR}, year = {2022} } @inproceedings{CycleGAN, title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, booktitle={ICCV}, year={2017} } ``` ================================================ FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_cityscapes.yaml ================================================ MODEL: META_ARCHITECTURE: "TLGeneralizedRCNN" WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" MASK_ON: False RESNETS: DEPTH: 101 ROI_HEADS: NAME: "TLRes5ROIHeads" NUM_CLASSES: 8 BATCH_SIZE_PER_IMAGE: 512 ANCHOR_GENERATOR: SIZES: [ [ 64, 128, 256, 512 ] ] RPN: PRE_NMS_TOPK_TEST: 6000 POST_NMS_TOPK_TEST: 1000 BATCH_SIZE_PER_IMAGE: 256 PROPOSAL_GENERATOR: NAME: "TLRPN" INPUT: MIN_SIZE_TRAIN: (512, 544, 576, 608, 640, 672, 704,) MIN_SIZE_TEST: 608 MAX_SIZE_TRAIN: 1166 DATASETS: TRAIN: ("cityscapes_trainval",) TEST: ("cityscapes_test",) SOLVER: STEPS: (12000,) MAX_ITER: 16000 # 16 epochs WARMUP_ITERS: 100 CHECKPOINT_PERIOD: 2000 IMS_PER_BATCH: 2 BASE_LR: 0.005 TEST: EVAL_PERIOD: 2000 VIS_PERIOD: 500 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_R_101_C4_voc.yaml ================================================ MODEL: META_ARCHITECTURE: "TLGeneralizedRCNN" WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" MASK_ON: False RESNETS: DEPTH: 101 ROI_HEADS: NAME: "TLRes5ROIHeads" NUM_CLASSES: 20 BATCH_SIZE_PER_IMAGE: 256 ANCHOR_GENERATOR: SIZES: [ [ 64, 128, 256, 512 ] ] RPN: PRE_NMS_TOPK_TEST: 6000 POST_NMS_TOPK_TEST: 1000 BATCH_SIZE_PER_IMAGE: 128 PROPOSAL_GENERATOR: NAME: "TLRPN" INPUT: MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704,) 
MIN_SIZE_TEST: 608 MAX_SIZE_TRAIN: 1166 DATASETS: TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') TEST: ('voc_2007_test',) SOLVER: STEPS: (12000, ) MAX_ITER: 16000 # 16 epochs WARMUP_ITERS: 100 CHECKPOINT_PERIOD: 2000 IMS_PER_BATCH: 4 BASE_LR: 0.005 TEST: EVAL_PERIOD: 2000 VIS_PERIOD: 500 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/config/faster_rcnn_vgg_16_cityscapes.yaml ================================================ MODEL: META_ARCHITECTURE: "TLGeneralizedRCNN" WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth' PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MASK_ON: False BACKBONE: NAME: "build_vgg_fpn_backbone" ROI_HEADS: IN_FEATURES: ["p3", "p4", "p5", "p6"] NAME: "TLStandardROIHeads" NUM_CLASSES: 8 ROI_BOX_HEAD: NAME: "FastRCNNConvFCHead" NUM_FC: 2 POOLER_RESOLUTION: 7 ANCHOR_GENERATOR: SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ] # Three aspect RPN: IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"] PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level PRE_NMS_TOPK_TEST: 1000 # Per FPN level # Detectron1 uses 2000 proposals per-batch, # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
POST_NMS_TOPK_TRAIN: 1000 POST_NMS_TOPK_TEST: 1000 PROPOSAL_GENERATOR: NAME: "TLRPN" INPUT: FORMAT: "RGB" MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1280 MAX_SIZE_TRAIN: 1280 DATASETS: TRAIN: ("cityscapes_trainval",) TEST: ("cityscapes_test",) SOLVER: STEPS: (12000,) MAX_ITER: 16000 # 16 epochs WARMUP_ITERS: 100 CHECKPOINT_PERIOD: 2000 IMS_PER_BATCH: 8 BASE_LR: 0.01 TEST: EVAL_PERIOD: 2000 VIS_PERIOD: 500 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/config/retinanet_R_101_FPN_voc.yaml ================================================ MODEL: META_ARCHITECTURE: "TLRetinaNet" WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" BACKBONE: NAME: "build_retinanet_resnet_fpn_backbone" MASK_ON: False RESNETS: DEPTH: 101 OUT_FEATURES: [ "res4", "res5" ] ANCHOR_GENERATOR: SIZES: !!python/object/apply:eval [ "[[x, x * 2**(1.0/3), x * 2**(2.0/3) ] for x in [64, 128, 256, 512 ]]" ] RETINANET: NUM_CLASSES: 20 IN_FEATURES: ["p4", "p5", "p6", "p7"] IOU_THRESHOLDS: [ 0.4, 0.5 ] IOU_LABELS: [ 0, -1, 1 ] SMOOTH_L1_LOSS_BETA: 0.0 FPN: IN_FEATURES: ["res4", "res5"] INPUT: MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, ) MIN_SIZE_TEST: 608 MAX_SIZE_TRAIN: 1166 DATASETS: TRAIN: ('voc_2007_trainval', 'voc_2012_trainval') TEST: ('voc_2007_test',) SOLVER: STEPS: (12000, ) MAX_ITER: 16000 # 16 epochs WARMUP_ITERS: 100 CHECKPOINT_PERIOD: 2000 IMS_PER_BATCH: 8 BASE_LR: 0.005 TEST: EVAL_PERIOD: 2000 VIS_PERIOD: 500 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/cycle_gan.py ================================================ """ CycleGAN for VOC-format Object Detection Dataset You need to modify function build_dataset if you want to use your own dataset. 
def make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so both sides are positive multiples of ``base`` (no-op when already aligned)."""
    ow, oh = img.size
    # Round each side to the nearest multiple of ``base``, never below ``base`` itself.
    new_h = int(max(round(oh / base), 1) * base)
    new_w = int(max(round(ow / base), 1) * base)
    if (new_w, new_h) == (ow, oh):
        return img
    return img.resize((new_w, new_h), method)


class VOCImageFolder(datasets.VisionDataset):
    """A VOC-format dataset for image translation, yielding (image, path) pairs."""

    def __init__(self, root: str, phase='trainval', transform: Optional[Callable] = None, extension='.jpg'):
        super().__init__(root, transform=transform)
        data_list_file = os.path.join(root, "ImageSets/Main/{}.txt".format(phase))
        self.samples = self.parse_data_file(data_list_file, extension)
        self.loader = default_loader
        self.data_list_file = data_list_file

    def __getitem__(self, index: int) -> Tuple[Any, str]:
        """Return the (optionally transformed) image at ``index`` along with its file path."""
        path = self.samples[index]
        img = self.loader(path)
        return (img if self.transform is None else self.transform(img)), path

    def __len__(self) -> int:
        return len(self.samples)

    def parse_data_file(self, file_name: str, extension: str) -> List[str]:
        """Read an ImageSets split file and return the list of image paths.

        Args:
            file_name (str): The path of the split file (one image stem per line).
            extension (str): Suffix appended to each stem; ``None`` keeps lines as-is.
        """
        def _to_path(stem):
            candidate = stem if extension is None else stem + extension
            if os.path.isabs(candidate):
                return candidate
            # Relative entries live under the VOC JPEGImages directory.
            return os.path.join(self.root, "JPEGImages", candidate)

        with open(file_name, "r") as f:
            return [_to_path(line.strip()) for line in f.readlines()]

    def translate(self, transform: Callable, target_root: str, image_base=4):
        """Translate every image with ``transform`` and save the results under ``target_root``.

        Args:
            transform (callable): maps a PIL image from one domain to another.
            target_root (str): root directory that mirrors ``self.root`` for the outputs.
            image_base (int): the generator expects sides that are multiples of this value.
        """
        os.makedirs(target_root, exist_ok=True)
        for path in tqdm.tqdm(self.samples):
            source = Image.open(path).convert('RGB')
            destination = path.replace(self.root, target_root)
            ow, oh = source.size
            # Align to ``image_base`` for the generator, then restore the original
            # size so existing VOC annotations stay valid for the translated image.
            aligned = make_power_2(source, image_base)
            result = transform(aligned).resize((ow, oh))
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            result.save(destination)
def main(args):
    """Train (or resume) a CycleGAN and optionally translate whole datasets with it."""
    logger = CompleteLogger(args.log, args.phase)

    # Optional deterministic mode for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading: identical augmentation for both domains, pixels normalised to [-1, 1].
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # args.source / args.target interleave dataset names and their root directories.
    train_source_dataset = build_dataset(args.source[::2], args.source[1::2], train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True,
                                     num_workers=args.workers, pin_memory=True, drop_last=True)
    train_target_dataset = build_dataset(args.target[::2], args.target[1::2], train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True,
                                     num_workers=args.workers, pin_memory=True, drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # Networks: one generator per direction, one discriminator per domain.
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)

    # History buffers of generated images, used to stabilise discriminator training.
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # Joint optimizers: one over both generators, one over both discriminators.
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    # Constant lr for the first ``epochs`` epochs, then linear decay to zero
    # over the following ``epochs_decay`` epochs.
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # Optionally resume every piece of training state from a checkpoint.
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'train':
        # Adversarial, cycle-consistency and identity criteria.
        criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
        criterion_cycle = nn.L1Loss()
        criterion_identity = nn.L1Loss()

        tensor_to_image = Compose([
            Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ToPILImage()
        ])

        def visualize(image, name):
            """Save a tensor of shape 3 x H x W under the logger's image directory."""
            tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))

        # Training runs ``epochs`` constant-lr epochs plus ``epochs_decay`` decaying ones.
        for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
            logger.set_epoch(epoch)
            print(lr_scheduler_G.get_lr())

            train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
                  criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
                  fake_S_pool, fake_T_pool, epoch, visualize, args)

            lr_scheduler_G.step()
            lr_scheduler_D.step()

            # Persist everything needed to resume training later.
            torch.save({
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path('latest'))

    # Translate the full datasets with the trained generators when output roots are given.
    if args.translated_source is not None:
        transform = cyclegan.transform.Translation(netG_S2T, device)
        for dataset, translated_source in zip(train_source_dataset.datasets, args.translated_source):
            dataset.translate(transform, translated_source, image_base=args.image_base)
    if args.translated_target is not None:
        transform = cyclegan.transform.Translation(netG_T2S, device)
        for dataset, translated_target in zip(train_target_dataset.datasets, args.translated_target):
            dataset.translate(transform, translated_target, image_base=args.image_base)

    logger.close()
def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, criterion_gan,
          criterion_cycle, criterion_identity, optimizer_G, optimizer_D, fake_S_pool, fake_T_pool,
          epoch: int, visualize, args: argparse.Namespace):
    """Train the CycleGAN for one epoch.

    Each iteration runs a generator step (adversarial + cycle + identity losses,
    with frozen discriminators) followed by a discriminator step on detached
    fakes drawn from the history pools.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)
        data_time.update(time.time() - end)

        # Forward both translation directions plus the cycle reconstructions.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ---- Generator step: the discriminators receive no gradients. ----
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # Adversarial terms: D_T(G_S2T(S)) and D_S(G_T2S(T)) should look real.
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle terms: ||G_T2S(G_S2T(S)) - S|| and ||G_S2T(G_T2S(T)) - T||.
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity terms keep each generator close to a no-op on its own domain.
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        loss_G = (loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T
                  + loss_identity_S + loss_identity_T)
        loss_G.backward()
        optimizer_G.step()

        # ---- Discriminator step: train on pooled, detached fakes. ----
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        pooled_fake_S = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(pooled_fake_S), False))
        loss_D_S.backward()
        pooled_fake_T = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(pooled_fake_T), False))
        loss_D_T.backward()
        optimizer_D.step()

        # ---- Book-keeping and periodic visualisation. ----
        batch_size = real_S.size(0)
        losses_G_S2T.update(loss_G_S2T.item(), batch_size)
        losses_G_T2S.update(loss_G_T2S.item(), batch_size)
        losses_D_S.update(loss_D_S.item(), batch_size)
        losses_D_T.update(loss_D_T.item(), batch_size)
        losses_cycle_S.update(loss_cycle_S.item(), batch_size)
        losses_cycle_T.update(loss_cycle_T.item(), batch_size)
        losses_identity_S.update(loss_identity_S.item(), batch_size)
        losses_identity_T.update(loss_identity_T.item(), batch_size)
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                                     "identity_S", "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))
def build_dataset(dataset_names, dataset_roots, transform):
    """Build a concatenation of VOC-format datasets.

    Given a sequence of dataset class names and a parallel sequence of root
    directories, construct one :class:`VOCImageFolder` per pair (choosing the
    split file and image extension the dataset requires) and concatenate them.
    """
    # Datasets that deviate from the default "trainval" split with ".jpg" images.
    special_kwargs = {
        "WaterColor": dict(phase='train'),
        "Comic": dict(phase='train'),
        "Cityscapes": dict(phase="trainval", extension=".png"),
        "FoggyCityscapes": dict(phase="trainval", extension=".png"),
        "Sim10k": dict(phase="trainval10k"),
    }
    parts = [
        VOCImageFolder(root, transform=transform, **special_kwargs.get(name, dict(phase="trainval")))
        for name, root in zip(dataset_names, dataset_roots)
    ]
    return ConcatDataset(parts)
if __name__ == '__main__':
    # Fixed: the description previously said "Segmentation" although this script
    # belongs to the object-detection examples.
    parser = argparse.ArgumentParser(description='CycleGAN for Object Detection')
    # dataset parameters
    parser.add_argument('-s', '--source', nargs='+', help='source domain(s)')
    parser.add_argument('-t', '--target', nargs='+', help='target domain(s)')
    parser.add_argument('--rotation', type=int, default=0,
                        help='rotation range of the RandomRotation augmentation')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(0.5, 1.0),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(3. / 4., 4. / 3.),
                        help='the resize scale for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 512),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64,
                        help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64,
                        help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. '
                             'The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0,
                        help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0,
                        help='trade off for identity loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    # Fixed: help text previously claimed "(default: 100)" while the actual default is 500.
    parser.add_argument('-p', '--print-freq', default=500, type=int, metavar='N',
                        help='print frequency (default: 500)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='cyclegan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(512, 512),
                        help='the input image size during test')
    parser.add_argument('--translated-source', type=str, default=None, nargs='+',
                        help="The root to put the translated source dataset")
    parser.add_argument('--translated-target', type=str, default=None, nargs='+',
                        help="The root to put the translated target dataset")
    parser.add_argument('--image-base', default=4, type=int,
                        help='the input image will be multiple of image-base before translated')
    args = parser.parse_args()
    print(args)
    main(args)
VOC2012 datasets/VOC2012_to_clipart \
  -t Clipart datasets/clipart \
  --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2clipart

# VOC to Comic
# Copy the VOC datasets so translated images can replace the originals in place.
mkdir datasets/VOC2007_to_comic
cp -r datasets/VOC2007/* datasets/VOC2007_to_comic
mkdir datasets/VOC2012_to_comic
cp -r datasets/VOC2012/* datasets/VOC2012_to_comic
# Train the CycleGAN translators and write comic-style copies of the VOC images.
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Comic datasets/comic \
  --translated-source datasets/VOC2007_to_comic datasets/VOC2012_to_comic \
  --log logs/cyclegan_resnet9/translation/voc2comic --netG resnet_9
# Train the detector on original + translated sources, evaluate on Comic.
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_comic VOC2012Partial datasets/VOC2012_to_comic \
  -t Comic datasets/comic \
  --test VOC2007Test datasets/VOC2007 ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2comic MODEL.ROI_HEADS.NUM_CLASSES 6

# VOC to WaterColor
mkdir datasets/VOC2007_to_watercolor
cp -r datasets/VOC2007/* datasets/VOC2007_to_watercolor
mkdir datasets/VOC2012_to_watercolor
cp -r datasets/VOC2012/* datasets/VOC2012_to_watercolor
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py \
  -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t WaterColor datasets/watercolor \
  --translated-source datasets/VOC2007_to_watercolor datasets/VOC2012_to_watercolor \
  --log logs/cyclegan_resnet9/translation/voc2watercolor --netG resnet_9
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 VOC2007Partial datasets/VOC2007_to_watercolor VOC2012Partial datasets/VOC2012_to_watercolor \
  -t WaterColor datasets/watercolor \
  --test VOC2007Test datasets/VOC2007 WaterColorTest datasets/watercolor --finetune \
  OUTPUT_DIR logs/cyclegan_resnet9/faster_rcnn_R_101_C4/voc2watercolor MODEL.ROI_HEADS.NUM_CLASSES 6

# Cityscapes to Foggy Cityscapes
mkdir datasets/cityscapes_to_foggy_cityscapes
cp -r datasets/cityscapes_in_voc/* datasets/cityscapes_to_foggy_cityscapes
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Cityscapes datasets/cityscapes_in_voc \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --translated-source datasets/cityscapes_to_foggy_cityscapes \
  --log logs/cyclegan/translation/cityscapes2foggy

# ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes datasets/cityscapes_to_foggy_cityscapes/ \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/cityscapes2foggy

# VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \
  -s Cityscapes datasets/cityscapes_in_voc/ Cityscapes datasets/cityscapes_to_foggy_cityscapes/ \
  -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \
  --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \
  OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/cityscapes2foggy

# Sim10k to Cityscapes Car
mkdir datasets/sim10k_to_cityscapes_car
cp -r datasets/sim10k/* datasets/sim10k_to_cityscapes_car
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s Sim10k datasets/sim10k -t Cityscapes datasets/cityscapes_in_voc \
  --log logs/cyclegan/translation/sim10k2cityscapes_car --translated-source datasets/sim10k_to_cityscapes_car --image-base 256

# ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car
CUDA_VISIBLE_DEVICES=0 python source_only.py \
  --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml
\ -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 # VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Sim10kCar datasets/sim10k Sim10kCar datasets/sim10k_to_cityscapes_car -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/cyclegan/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 # GTA5 to Cityscapes mkdir datasets/gta5_to_cityscapes cp -r datasets/synscapes_detection/* datasets/gta5_to_cityscapes CUDA_VISIBLE_DEVICES=0 python cycle_gan.py -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc \ --log logs/cyclegan/translation/gta52cityscapes --translated-source datasets/gta5_to_cityscapes --image-base 256 # ResNet101 Based Faster RCNN: GTA5 -> Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s GTA5 datasets/synscapes_detection GTA5 datasets/gta5_to_cityscapes -t Cityscapes datasets/cityscapes_in_voc \ --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/cyclegan/faster_rcnn_R_101_C4/gta52cityscapes ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/README.md ================================================ # Decoupled Adaptation for Cross-Domain Object Detection ## Installation Our code is based on - [Detectron latest(v0.6)](https://detectron2.readthedocs.io/en/latest/tutorials/install.html) - [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models) please install them before usage. 
## Method

Compared with previous cross-domain object detection methods, D-adapt decouples the adversarial adaptation from the training of the detector.
Editor
The whole pipeline is as follows:
Editor
First, you need to run ``source_only.py`` to obtain pre-trained models. (See source_only.sh for scripts.) Then you need to run ``d_adapt.py`` to obtain adapted models. (See d_adapt.sh for scripts). When the domain discrepancy is large, you need to run ``d_adapt.py`` multiple times. For better readability, we implement the training of category adaptor in ``category_adaptation.py``, implement the training of the bounding box adaptor in``bbox_adaptation.py``, and implement the training of the detector and connect the above components in ``d_adapt.py``. This can facilitate you to modify and replace other adaptors. We provide independent training arguments for detector, category adaptor and bounding box adaptor. The arguments of latter two end with ``-c`` and ``-b`` respectively. ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{jiang2021decoupled, title = {Decoupled Adaptation for Cross-Domain Object Detection}, author = {Junguang Jiang and Baixu Chen and Jianmin Wang and Mingsheng Long}, booktitle = {ICLR}, year = {2022} } ``` ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/bbox_adaptation.py ================================================ """ Training a bounding box adaptor @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import os.path as osp import argparse from collections import deque import tqdm import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD, Adam from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F from detectron2.modeling.box_regression import Box2BoxTransform from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.modules.regressor 
class BoxTransform(nn.Module):
    """Select the regression deltas of the ground-truth class and decode them into boxes."""

    def __init__(self):
        super(BoxTransform, self).__init__()
        BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
        self.box_transform = Box2BoxTransform(weights=BBOX_REG_WEIGHTS)

    def forward(self, pred_delta, gt_classes, proposal_boxes):
        """
        Args:
            - pred_delta: predicted bounding box offset for each class
            - gt_classes: ground truth classes
            - proposal_boxes: referenced bounding box

        Returns: predicted bounding box offset for the ground truth classes
            and the decoded bounding box
        """
        # Columns [4c, 4c+1, 4c+2, 4c+3] hold the deltas of class c.
        columns = 4 * gt_classes[:, None] + torch.arange(4, device=device)
        selected_delta = torch.gather(pred_delta, dim=1, index=columns)
        decoded_box = self.box_transform.apply_deltas(selected_delta, proposal_boxes)
        return selected_delta, decoded_box


def iou_between(
    boxes1: torch.Tensor, boxes2: torch.Tensor, eps: float = 1e-7, reduction: str = "none"
):
    """Element-wise Intersections over Union between two sets of XYXY boxes."""
    x1, y1, x2, y2 = boxes1.unbind(dim=-1)
    x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
    assert (x2 >= x1).all(), "bad box: x1 larger than x2"
    assert (y2 >= y1).all(), "bad box: y1 larger than y2"

    # Corners of the overlap rectangle; non-positive extents mean disjoint boxes.
    left = torch.max(x1, x1g)
    top = torch.max(y1, y1g)
    right = torch.min(x2, x2g)
    bottom = torch.min(y2, y2g)

    intersection = torch.zeros_like(x1)
    overlapping = (bottom > top) & (right > left)
    intersection[overlapping] = (right[overlapping] - left[overlapping]) * (bottom[overlapping] - top[overlapping])

    union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intersection
    iou = intersection / (union + eps)
    if reduction == 'mean':
        return iou.mean()
    if reduction == 'sum':
        return iou.sum()
    return iou
def clamp_single(box, w, h):
    """Clip one XYXY box (a length-4 tensor) to the [0, w] x [0, h] image rectangle."""
    x1, y1, x2, y2 = box
    return torch.tensor((
        x1.clamp(min=0, max=w),
        y1.clamp(min=0, max=h),
        x2.clamp(min=0, max=w),
        y2.clamp(min=0, max=h),
    ))


def clamp(boxes, widths, heights):
    """clamp (limit) the values in boxes within the widths and heights of the image."""
    return torch.stack(
        [clamp_single(box, w, h) for box, w, h in zip(boxes, widths, heights)],
        dim=0,
    )
self.logger.get_checkpoint_path('latest') if osp.exists(path): checkpoint = torch.load(path, map_location='cpu') self.model.load_state_dict(checkpoint) return True else: return False def prepare_training_data(self, proposal_list: PersistentProposalList, labeled=True): if not labeled: # remove (predicted) background proposals filtered_proposals_list = [] for proposals in proposal_list: keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names)) filtered_proposals_list.append(proposals[keep_indices]) else: # remove proposals with low IoU filtered_proposals_list = [] for proposals in proposal_list: keep_indices = proposals.gt_ious > 0.3 filtered_proposals_list.append(proposals[keep_indices]) filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ T.Resize((self.args.resize_size, self.args.resize_size)), # T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3), # T.RandomGrayscale(), T.ToTensor(), normalize ]) dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand)) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.workers, drop_last=True) return dataloader def prepare_validation_data(self, proposal_list: PersistentProposalList): normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ T.Resize((self.args.resize_size, self.args.resize_size)), T.ToTensor(), normalize ]) # remove (predicted) background proposals filtered_proposals_list = [] for proposals in proposal_list: # keep_indices = (0 <= proposals.gt_classes) & (proposals.gt_classes < len(self.class_names)) keep_indices = (0 <= proposals.pred_classes) & (proposals.pred_classes < len(self.class_names)) filtered_proposals_list.append(proposals[keep_indices]) filtered_proposals_list = 
flatten(filtered_proposals_list, self.args.max_val) dataset = ProposalDataset(filtered_proposals_list, transform, crop_func=ExpandCrop(self.args.expand)) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.workers, drop_last=False) return dataloader def prepare_test_data(self, proposal_list: PersistentProposalList): normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ T.Resize((self.args.resize_size, self.args.resize_size)), T.ToTensor(), normalize ]) dataset = ProposalDataset(proposal_list, transform, crop_func=ExpandCrop(self.args.expand)) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.workers, drop_last=False) return dataloader def predict(self, data_loader): # switch to evaluate mode self.model.eval() predictions = deque() with torch.no_grad(): for images, labels in tqdm.tqdm(data_loader): images = images.to(device) pred_classes = labels['pred_classes'].to(device) pred_boxes = labels['pred_boxes'].to(device).float() # compute output pred_deltas = self.model(images) _, pred_boxes = self.box_transform(pred_deltas, pred_classes, pred_boxes) pred_boxes = clamp(pred_boxes.cpu(), labels['width'], labels['height']) pred_boxes = pred_boxes.numpy().tolist() for p in pred_boxes: predictions.append(p) return predictions def validate_baseline(self, val_loader): """call this function if you have labeled data for validation""" ious = AverageMeter("IoU", ":.4e") print("Calculate baseline IoU:") for _, labels in tqdm.tqdm(val_loader): gt_boxes = labels['gt_boxes'] pred_boxes = labels['pred_boxes'] ious.update(iou_between(pred_boxes, gt_boxes).mean().item(), gt_boxes.size(0)) print(' * Baseline IoU {:.3f}'.format(ious.avg)) return ious.avg @staticmethod def validate(val_loader, model, box_transform, args) -> float: """call this function if you have labeled data for validation""" batch_time = AverageMeter('Time', ':6.3f') 
    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):
        """When no labels exists on target domain, please set data_loader_validation=None"""
        # Two-stage training of the bounding-box adaptor:
        #   1. pre-train a fresh regressor on the (labeled) source domain only;
        #   2. train self.model on both domains with a regression-MDD loss.
        args = self.args
        print(args)
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
        cudnn.benchmark = True
        iter_source = ForeverDataIterator(data_loader_source)
        iter_target = ForeverDataIterator(data_loader_target)
        best_iou = 0.
        box_transform = self.box_transform
        # first pre-train on the source domain
        model = Regressor(
            self.model.backbone, len(self.class_names) * 4,
            bottleneck=nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            ),
            head=nn.Linear(self.model.backbone.out_features, len(self.class_names) * 4),
            bottleneck_dim=self.model.backbone.out_features
        ).to(device)
        optimizer = Adam(model.get_parameters(), args.pretrain_lr, weight_decay=args.pretrain_weight_decay)
        lr_scheduler = LambdaLR(optimizer,
                                lambda x: args.pretrain_lr * (1. + args.pretrain_lr_gamma * float(x)) ** (-args.pretrain_lr_decay))
        for epoch in range(args.pretrain_epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])
            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            ious = AverageMeter("IoU", ":.4e")
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, ious],
                prefix="Epoch: [{}]".format(epoch))
            # switch to train mode
            model.train()
            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_s = x_s.to(device)
                # bounding box offsets (regression targets between proposal and ground-truth boxes)
                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'],
                                                                 labels_s['gt_boxes']).to(device).float()
                pred_boxes_s = labels_s['pred_boxes'].to(device).float()
                gt_classes_s = labels_s['gt_fg_classes'].to(device)
                gt_boxes_s = labels_s['gt_boxes'].to(device).float()
                # measure data loading time
                data_time.update(time.time() - end)
                # compute output
                pred_delta_s, _ = model(x_s)
                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)
                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)
                loss = reg_loss
                losses.update(loss.item(), x_s.size(0))
                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % args.print_freq == 0:
                    progress.display(i)
            # evaluate on validation set
            if data_loader_validation is not None:
                iou = self.validate(data_loader_validation, model, box_transform, args)
                best_iou = max(iou, best_iou)
        # training on both domains
        model = self.model
        optimizer = SGD(model.get_parameters(), args.lr, momentum=args.momentum,
                        weight_decay=args.weight_decay, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
        for epoch in range(args.epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])
            # train for one epoch
            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            ious = AverageMeter("IoU", ":.4e")
            ious_t = AverageMeter("IoU (t)", ":.4e")
            ious_s_adv = AverageMeter("IoU (s, adv)", ":.4e")
            ious_t_adv = AverageMeter("IoU (t, adv)", ":.4e")
            trans_losses = AverageMeter('Trans Loss', ':3.2f')
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, trans_losses, ious, ious_t, ious_s_adv, ious_t_adv],
                prefix="Epoch: [{}]".format(epoch))
            # switch to train mode
            model.train()
            mdd = RegressionMarginDisparityDiscrepancy(args.margin).to(device)
            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_t, labels_t = next(iter_target)
                x_s = x_s.to(device)
                x_t = x_t.to(device)
                # bounding box offsets
                delta_s = box_transform.box_transform.get_deltas(labels_s['pred_boxes'],
                                                                 labels_s['gt_boxes']).to(device).float()
                pred_boxes_s = labels_s['pred_boxes'].to(device).float()
                gt_classes_s = labels_s['gt_fg_classes'].to(device)
                gt_boxes_s = labels_s['gt_boxes'].to(device).float()
                pred_boxes_t = labels_t['pred_boxes'].to(device).float()
                # NOTE(review): on the target domain the "classes" come from the
                # detector's predictions (pseudo labels), not from annotations
                gt_classes_t = labels_t['pred_classes'].to(device)
                gt_boxes_t = labels_t['gt_boxes'].to(device).float()
                # measure data loading time
                data_time.update(time.time() - end)
                # compute output
                x = torch.cat([x_s, x_t], dim=0)
                outputs, outputs_adv = model(x)
                pred_delta_s, pred_delta_t = outputs.chunk(2, dim=0)
                pred_delta_s_adv, pred_delta_t_adv = outputs_adv.chunk(2, dim=0)
                pred_delta_s, pred_boxes_s = box_transform(pred_delta_s, gt_classes_s, pred_boxes_s)
                pred_delta_t, pred_boxes_t = box_transform(pred_delta_t, gt_classes_t, pred_boxes_t)
                pred_delta_s_adv, pred_boxes_s_adv = box_transform(pred_delta_s_adv, gt_classes_s, pred_boxes_s)
                pred_delta_t_adv, pred_boxes_t_adv = box_transform(pred_delta_t_adv, gt_classes_t, pred_boxes_t)
                reg_loss = F.smooth_l1_loss(pred_delta_s, delta_s)
                # compute margin disparity discrepancy between domains
                transfer_loss = mdd(pred_delta_s, pred_delta_s_adv, pred_delta_t, pred_delta_t_adv)
                # for adversarial classifier, minimize negative mdd is equal to maximize mdd
                loss = reg_loss - transfer_loss * args.trade_off
                model.step()
                losses.update(loss.item(), x_s.size(0))
                ious.update(iou_between(pred_boxes_s.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))
                ious_t.update(iou_between(pred_boxes_t.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))
                ious_s_adv.update(iou_between(pred_boxes_s_adv.cpu(), gt_boxes_s.cpu()).mean().item(), x_s.size(0))
                ious_t_adv.update(iou_between(pred_boxes_t_adv.cpu(), gt_boxes_t.cpu()).mean().item(), x_s.size(0))
                trans_losses.update(transfer_loss.item(), x_s.size(0))
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % args.print_freq == 0:
                    progress.display(i)
            # evaluate on validation set
            if data_loader_validation is not None:
                iou = self.validate(data_loader_validation, model, box_transform, args)
                best_iou = max(iou, best_iou)
        # save checkpoint
        torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))
        print("best_iou = {:3.1f}".format(best_iou))
        self.logger.logger.flush()
parser.add_argument('--resize-size-b', type=int, default=224, help='the image size after resizing') parser.add_argument('--max-train-b', type=int, default=10) parser.add_argument('--max-val-b', type=int, default=10) parser.add_argument('--expand-b', type=float, default=2., help='The expanding ratio between the input of the bounding box adaptor' '(the crops of objects) and the the original predicted box.') # model parameters parser.add_argument('--arch-b', metavar='ARCH', default='resnet101', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet101)') parser.add_argument('--bottleneck-dim-b', default=1024, type=int, help='Dimension of bottleneck') parser.add_argument('--no-pool-b', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch-b', action='store_true', help='whether train from scratch.') parser.add_argument('--margin', type=float, default=4., help="margin hyper-parameter") parser.add_argument('--trade-off', default=0.1, type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('--batch-size-b', default=32, type=int, metavar='N', help='mini-batch size (default: 64)') parser.add_argument('--lr-b', default=0.004, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--weight-decay-b', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--workers-b', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs-b', default=2, type=int, metavar='N', help='number of total epochs to run') 
parser.add_argument('--pretrain-lr-b', default=0.001, type=float, metavar='LR', help='initial learning rate') parser.add_argument('--pretrain-lr-gamma-b', default=0.0002, type=float, help='parameter for lr scheduler') parser.add_argument('--pretrain-lr-decay-b', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--pretrain-weight-decay-b', default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)') parser.add_argument('--pretrain-epochs-b', default=10, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--iters-per-epoch-b', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('--print-freq-b', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed-b', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log-b", type=str, default='box', help="Where to save logs, checkpoints and debugging images.") return parser ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/category_adaptation.py ================================================ """ Training a category adaptor @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import os.path as osp from collections import deque import tqdm from typing import List import torch from torch import Tensor import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torchvision.transforms as T import torch.nn.functional as F sys.path.append('../../../..') from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier from tllib.alignment.d_adapt.proposal import ProposalDataset, flatten, Proposal from 
class ConfidenceBasedDataSelector:
    """Select data point based on confidence"""

    def __init__(self, confidence_ratio=0.1, category_names=()):
        # fraction of top-scoring proposals (per category) regarded as reliable
        self.confidence_ratio = confidence_ratio
        self.categories = []
        self.scores = []
        self.category_names = category_names
        self.per_category_thresholds = None  # filled in by calculate()

    def extend(self, categories, scores):
        """Accumulate (category, score) observations for later threshold fitting."""
        self.categories.extend(categories)
        self.scores.extend(scores)

    def calculate(self):
        """Derive one confidence threshold per category from the collected scores."""
        buckets = {name: [] for name in self.category_names}
        for category, score in zip(self.categories, self.scores):
            buckets[category].append(score)
        print(buckets.keys())
        thresholds = {}
        for category, bucket in buckets.items():
            bucket.sort(reverse=True)
            cut = int(self.confidence_ratio * len(bucket))
            print(category, len(bucket), cut)
            # categories with no observations get an unreachable threshold (1.)
            thresholds[category] = bucket[cut] if len(bucket) else 1.
        print('----------------------------------------------------')
        print("confidence threshold for each category:")
        for category in self.category_names:
            print('\t', category, round(thresholds[category], 3))
        print('----------------------------------------------------')
        self.per_category_thresholds = thresholds

    def whether_select(self, categories, scores):
        """Return, per proposal, whether its score strictly beats its category threshold."""
        assert self.per_category_thresholds is not None, "please call calculate before selection!"
        return [score > self.per_category_thresholds[category]
                for category, score in zip(categories, scores)]
class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
    """Cross-entropy that's robust to label noise"""

    def __init__(self, *args, offset=0.1, **kwargs):
        # `offset` shifts the logits before they are clamped at 1; it is stored
        # before the parent constructor runs so forward() can always read it
        self.offset = offset
        super(RobustCrossEntropyLoss, self).__init__(*args, **kwargs)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        # clamp the shifted logits so no single logit can dominate the loss,
        # then normalize by the full batch size (ignored targets included)
        capped_logits = torch.clamp(input + self.offset, max=1.)
        total = F.cross_entropy(capped_logits, target, weight=self.weight,
                                ignore_index=self.ignore_index, reduction='sum')
        return total / input.shape[0]
for proposals in proposal_list: keep_indices = ~((self.args.ignored_scores[0] < proposals.pred_scores) & (proposals.pred_scores < self.args.ignored_scores[1])) filtered_proposals_list.append(proposals[keep_indices]) # calculate confidence threshold for each cateogry on the target domain for proposals in filtered_proposals_list: self.selector.extend(proposals.pred_classes.tolist(), proposals.pred_scores.tolist()) self.selector.calculate() else: # remove proposals with ignored classes or ious between (ignored_ious[0], ignored_ious[1]) filtered_proposals_list = [] for proposals in proposal_list: keep_indices = (proposals.gt_classes != -1) & \ ~((self.args.ignored_ious[0] < proposals.gt_ious) & (proposals.gt_ious < self.args.ignored_ious[1])) filtered_proposals_list.append(proposals[keep_indices]) filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_train) normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ ResizeImage(self.args.resize_size), T.RandomHorizontalFlip(), T.ColorJitter(brightness=0.7, contrast=0.7, saturation=0.7, hue=0.5), T.RandomGrayscale(), T.ToTensor(), normalize ]) dataset = ProposalDataset(filtered_proposals_list, transform) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=True, num_workers=self.args.workers, drop_last=True) return dataloader def prepare_validation_data(self, proposal_list: List[Proposal]): """call this function if you have labeled data for validation""" normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ ResizeImage(self.args.resize_size), T.ToTensor(), normalize ]) # remove proposals with ignored classes filtered_proposals_list = [] for proposals in proposal_list: keep_indices = proposals.gt_classes != -1 filtered_proposals_list.append(proposals[keep_indices]) filtered_proposals_list = flatten(filtered_proposals_list, self.args.max_val) dataset = ProposalDataset(filtered_proposals_list, 
transform) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.workers, drop_last=False) return dataloader def prepare_test_data(self, proposal_list: List[Proposal]): normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) transform = T.Compose([ ResizeImage(self.args.resize_size), T.ToTensor(), normalize ]) dataset = ProposalDataset(proposal_list, transform) dataloader = DataLoader(dataset, batch_size=self.args.batch_size, shuffle=False, num_workers=self.args.workers, drop_last=False) return dataloader def fit(self, data_loader_source, data_loader_target, data_loader_validation=None): """When no labels exists on target domain, please set data_loader_validation=None""" args = self.args if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True iter_source = ForeverDataIterator(data_loader_source) iter_target = ForeverDataIterator(data_loader_target) model = self.model feature_dim = model.features_dim num_classes = len(self.class_names) + 1 if args.randomized: domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device) else: domain_discri = DomainDiscriminator(feature_dim * num_classes, hidden_size=1024).to(device) all_parameters = model.get_parameters() + domain_discri.get_parameters() # define optimizer and lr scheduler optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
    def fit(self, data_loader_source, data_loader_target, data_loader_validation=None):
        """When no labels exists on target domain, please set data_loader_validation=None"""
        # Trains the category adaptor with CDAN: supervised cross-entropy on the
        # source domain, robust cross-entropy on high-confidence target pseudo
        # labels, and a conditional domain-adversarial transfer loss.
        args = self.args
        if args.seed is not None:
            random.seed(args.seed)
            torch.manual_seed(args.seed)
            cudnn.deterministic = True
            warnings.warn('You have chosen to seed training. '
                          'This will turn on the CUDNN deterministic setting, '
                          'which can slow down your training considerably! '
                          'You may see unexpected behavior when restarting '
                          'from checkpoints.')
        cudnn.benchmark = True
        iter_source = ForeverDataIterator(data_loader_source)
        iter_target = ForeverDataIterator(data_loader_target)
        model = self.model
        feature_dim = model.features_dim
        num_classes = len(self.class_names) + 1  # foreground classes + background
        if args.randomized:
            domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024).to(device)
        else:
            domain_discri = DomainDiscriminator(feature_dim * num_classes, hidden_size=1024).to(device)
        all_parameters = model.get_parameters() + domain_discri.get_parameters()
        # define optimizer and lr scheduler
        optimizer = SGD(all_parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                        nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
        # define loss function
        domain_adv = ConditionalDomainAdversarialLoss(
            domain_discri, entropy_conditioning=args.entropy,
            num_classes=num_classes, features_dim=feature_dim, randomized=args.randomized,
            randomized_dim=args.randomized_dim
        ).to(device)
        # start training
        best_acc1 = 0.
        for epoch in range(args.epochs):
            print("lr:", lr_scheduler.get_last_lr()[0])
            # train for one epoch
            batch_time = AverageMeter('Time', ':3.1f')
            data_time = AverageMeter('Data', ':3.1f')
            losses = AverageMeter('Loss', ':3.2f')
            losses_t = AverageMeter('Loss(t)', ':3.2f')
            trans_losses = AverageMeter('Trans Loss', ':3.2f')
            cls_accs = AverageMeter('Cls Acc', ':3.1f')
            domain_accs = AverageMeter('Domain Acc', ':3.1f')
            progress = ProgressMeter(
                args.iters_per_epoch,
                [batch_time, data_time, losses, losses_t, trans_losses, cls_accs, domain_accs],
                prefix="Epoch: [{}]".format(epoch))
            # switch to train mode
            model.train()
            domain_adv.train()
            end = time.time()
            for i in range(args.iters_per_epoch):
                x_s, labels_s = next(iter_source)
                x_t, labels_t = next(iter_target)
                # assign pseudo labels for target-domain proposals with extremely high confidence
                selected = torch.tensor(
                    self.selector.whether_select(
                        labels_t['pred_classes'].numpy().tolist(),
                        labels_t['pred_scores'].numpy().tolist()
                    )
                )
                # unselected proposals get label -1 and are ignored by the losses below
                pseudo_classes_t = selected * labels_t['pred_classes'] + (~selected) * -1
                pseudo_classes_t = pseudo_classes_t.to(device)
                x_s = x_s.to(device)
                x_t = x_t.to(device)
                gt_classes_s = labels_s['gt_classes'].to(device)
                # measure data loading time
                data_time.update(time.time() - end)
                # compute output
                x = torch.cat((x_s, x_t), dim=0)
                y, f = model(x)
                y_s, y_t = y.chunk(2, dim=0)
                f_s, f_t = f.chunk(2, dim=0)
                cls_loss = F.cross_entropy(y_s, gt_classes_s, ignore_index=-1)
                cls_loss_t = RobustCrossEntropyLoss(ignore_index=-1, offset=args.epsilon)(y_t, pseudo_classes_t)
                transfer_loss = domain_adv(y_s, f_s, y_t, f_t)
                domain_acc = domain_adv.domain_discriminator_accuracy
                loss = cls_loss + transfer_loss * args.trade_off + cls_loss_t
                cls_acc = accuracy(y_s, gt_classes_s)[0]
                losses.update(loss.item(), x_s.size(0))
                cls_accs.update(cls_acc, x_s.size(0))
                domain_accs.update(domain_acc, x_s.size(0))
                trans_losses.update(transfer_loss.item(), x_s.size(0))
                losses_t.update(cls_loss_t.item(), x_s.size(0))
                # compute gradient and do SGD step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % args.print_freq == 0:
                    progress.display(i)
            # evaluate on validation set
            if data_loader_validation is not None:
                acc1 = self.validate(data_loader_validation, model, self.class_names, args)
                best_acc1 = max(acc1, best_acc1)
        # save checkpoint
        torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest'))
        print("best_acc1 = {:3.1f}".format(best_acc1))
        domain_adv.to(torch.device("cpu"))
        self.logger.logger.flush()
transfer_loss * args.trade_off + cls_loss_t cls_acc = accuracy(y_s, gt_classes_s)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc, x_s.size(0)) domain_accs.update(domain_acc, x_s.size(0)) trans_losses.update(transfer_loss.item(), x_s.size(0)) losses_t.update(cls_loss_t.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) # evaluate on validation set if data_loader_validation is not None: acc1 = self.validate(data_loader_validation, model, self.class_names, args) best_acc1 = max(acc1, best_acc1) # save checkpoint torch.save(model.state_dict(), self.logger.get_checkpoint_path('latest')) print("best_acc1 = {:3.1f}".format(best_acc1)) domain_adv.to(torch.device("cpu")) self.logger.logger.flush() def predict(self, data_loader): # switch to evaluate mode self.model.eval() predictions = deque() with torch.no_grad(): for images, _ in tqdm.tqdm(data_loader): images = images.to(device) # compute output output = self.model(images) prediction = output.argmax(-1).cpu().numpy().tolist() for p in prediction: predictions.append(p) return predictions @staticmethod def validate(val_loader, model, class_names, args) -> float: batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1], prefix='Test: ') # switch to evaluate mode model.eval() confmat = ConfusionMatrix(len(class_names)+1) with torch.no_grad(): end = time.time() for i, (images, labels) in enumerate(val_loader): images = images.to(device) gt_classes = labels['gt_classes'].to(device) # compute output output = model(images) loss = F.cross_entropy(output, gt_classes) # measure accuracy and record loss acc1, = accuracy(output, gt_classes, topk=(1,)) confmat.update(gt_classes, 
    @staticmethod
    def get_parser() -> argparse.ArgumentParser:
        # Argument (sub-)parser for the category adaptor; every option carries a
        # ``-c`` suffix so it can be merged into the main D-adapt parser without
        # clashing with the detector/bounding-box options.
        parser = argparse.ArgumentParser(add_help=False)
        # dataset parameters
        parser.add_argument('--resize-size-c', type=int, default=112, help='the image size after resizing')
        parser.add_argument('--ignored-scores-c', type=float, nargs='+', default=[0.05, 0.3])
        parser.add_argument('--max-train-c', type=int, default=10)
        parser.add_argument('--max-val-c', type=int, default=2)
        parser.add_argument('--ignored-ious-c', type=float, nargs='+', default=(0.4, 0.5),
                            help='the iou threshold for ignored boxes')
        # model parameters
        parser.add_argument('--arch-c', metavar='ARCH', default='resnet101', choices=utils.get_model_names(),
                            help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet101)')
        parser.add_argument('--bottleneck-dim-c', default=1024, type=int, help='Dimension of bottleneck')
        parser.add_argument('--no-pool-c', action='store_true', help='no pool layer after the feature extractor.')
        parser.add_argument('--scratch-c', action='store_true', help='whether train from scratch.')
        parser.add_argument('--randomized-c', action='store_true',
                            help='using randomized multi-linear-map (default: False)')
        parser.add_argument('--randomized-dim-c', default=1024, type=int,
                            help='randomized dimension when using randomized multi-linear-map (default: 1024)')
        parser.add_argument('--entropy-c', default=False, action='store_true', help='use entropy conditioning')
        parser.add_argument('--trade-off-c', default=1., type=float,
                            help='the trade-off hyper-parameter for transfer loss')
        parser.add_argument('--confidence-ratio-c', default=0.0, type=float)
        parser.add_argument('--epsilon-c', default=0.01, type=float,
                            help='epsilon hyper-parameter in Robust Cross Entropy')
        # training parameters
        parser.add_argument('--batch-size-c', default=64, type=int, metavar='N',
                            help='mini-batch size (default: 64)')
        # NOTE: dest='lr' / dest='weight_decay' below bypass the automatic
        # ``_c``-suffix stripping done in CategoryAdaptor.__init__
        parser.add_argument('--learning-rate-c', default=0.01, type=float, metavar='LR',
                            help='initial learning rate', dest='lr')
        parser.add_argument('--lr-gamma-c', default=0.001, type=float, help='parameter for lr scheduler')
        parser.add_argument('--lr-decay-c', default=0.75, type=float, help='parameter for lr scheduler')
        parser.add_argument('--momentum-c', default=0.9, type=float, metavar='M', help='momentum')
        parser.add_argument('--weight-decay-c', default=1e-3, type=float, metavar='W',
                            help='weight decay (default: 1e-3)', dest='weight_decay')
        parser.add_argument('--workers-c', default=2, type=int, metavar='N',
                            help='number of data loading workers (default: 2)')
        parser.add_argument('--epochs-c', default=10, type=int, metavar='N',
                            help='number of total epochs to run')
        parser.add_argument('--iters-per-epoch-c', default=1000, type=int, help='Number of iterations per epoch')
        parser.add_argument('--print-freq-c', default=100, type=int, metavar='N',
                            help='print frequency (default: 100)')
        parser.add_argument('--seed-c', default=None, type=int, help='seed for initializing training. ')
        parser.add_argument("--log-c", type=str, default='cdan',
                            help="Where to save logs, checkpoints and debugging images.")
        return parser
IMS_PER_BATCH: 4 BASE_LR: 0.00025 LR_SCHEDULER_NAME: "ExponentialLR" GAMMA: 0.1 TEST: EVAL_PERIOD: 500 VIS_PERIOD: 20 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/config/faster_rcnn_vgg_16_cityscapes.yaml ================================================ MODEL: META_ARCHITECTURE: "DecoupledGeneralizedRCNN" WEIGHTS: 'https://open-mmlab.oss-cn-beijing.aliyuncs.com/pretrain/vgg16_caffe-292e1171.pth' PIXEL_MEAN: [123.675, 116.280, 103.530] PIXEL_STD: [58.395, 57.120, 57.375] MASK_ON: False BACKBONE: NAME: "build_vgg_fpn_backbone" ROI_HEADS: IN_FEATURES: ["p3", "p4", "p5", "p6"] NAME: "DecoupledStandardROIHeads" NUM_CLASSES: 8 ROI_BOX_HEAD: NAME: "FastRCNNConvFCHead" NUM_FC: 2 POOLER_RESOLUTION: 7 ANCHOR_GENERATOR: SIZES: [ [ 32 ], [ 64 ], [ 128 ], [ 256 ], [ 512 ] ] # One size for each in feature map ASPECT_RATIOS: [ [ 0.5, 1.0, 2.0 ] ] # Three aspect RPN: IN_FEATURES: ["p3", "p4", "p5", "p6", "p7"] PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level PRE_NMS_TOPK_TEST: 1000 # Per FPN level POST_NMS_TOPK_TRAIN: 1000 POST_NMS_TOPK_TEST: 1000 PROPOSAL_GENERATOR: NAME: "TLRPN" INPUT: FORMAT: "RGB" MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) MIN_SIZE_TEST: 800 MAX_SIZE_TEST: 1280 MAX_SIZE_TRAIN: 1280 DATASETS: TRAIN: ("cityscapes_trainval",) TEST: ("cityscapes_test",) SOLVER: STEPS: (3999, ) MAX_ITER: 4000 # 4 epochs WARMUP_ITERS: 100 CHECKPOINT_PERIOD: 1000 IMS_PER_BATCH: 8 BASE_LR: 0.01 LR_SCHEDULER_NAME: "ExponentialLR" GAMMA: 0.1 TEST: EVAL_PERIOD: 500 VIS_PERIOD: 20 VERSION: 2 ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/config/retinanet_R_101_FPN_voc.yaml ================================================ MODEL: META_ARCHITECTURE: "DecoupledRetinaNet" WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" BACKBONE: NAME: "build_retinanet_resnet_fpn_backbone" MASK_ON: False RESNETS: DEPTH: 101 OUT_FEATURES: [ "res4", "res5" ] ANCHOR_GENERATOR: 
def generate_proposals(model, num_classes, dataset_names, cache_root, cfg):
    """Generate foreground proposals and background proposals from `model` and save them to the disk.

    Results are cached as ``{first_dataset}_fg.json`` / ``{first_dataset}_bg.json``
    under `cache_root`; when both caches load successfully, inference is skipped.

    Args:
        model: detector used to generate the proposals
        num_classes: number of foreground classes
        dataset_names: names of the registered datasets to run inference on
        cache_root: directory holding the proposal caches
        cfg: detectron2 config used to build the test loaders

    Returns:
        a pair ``(fg_proposals_list, bg_proposals_list)``
    """
    fg_cache = os.path.join(cache_root, "{}_fg.json".format(dataset_names[0]))
    bg_cache = os.path.join(cache_root, "{}_bg.json".format(dataset_names[0]))
    fg_proposals_list = PersistentProposalList(fg_cache)
    bg_proposals_list = PersistentProposalList(bg_cache)
    # BUGFIX: evaluate both loads up front. With the previous short-circuit
    # `fg.load() and bg.load()`, an fg-cache hit combined with a bg-cache miss
    # re-ran inference and appended duplicates to the already-populated fg list.
    fg_loaded = fg_proposals_list.load()
    bg_loaded = bg_proposals_list.load()
    if not (fg_loaded and bg_loaded):
        # start from empty lists so a partially available cache is not duplicated
        fg_proposals_list = PersistentProposalList(fg_cache)
        bg_proposals_list = PersistentProposalList(bg_cache)
        for dataset_name in dataset_names:
            data_loader = build_detection_test_loader(cfg, dataset_name, mapper=ProposalMapper(cfg, False))
            generator = ProposalGenerator(num_classes=num_classes)
            fg_proposals_list_data, bg_proposals_list_data = inference_on_dataset(model, data_loader, generator)
            fg_proposals_list.extend(fg_proposals_list_data)
            bg_proposals_list.extend(bg_proposals_list_data)
        fg_proposals_list.flush()
        bg_proposals_list.flush()
    return fg_proposals_list, bg_proposals_list
background proposals from `model` and save them to the disk""" fg_proposals_list = PersistentProposalList(os.path.join(cache_root, "{}_fg.json".format(dataset_names[0]))) bg_proposals_list = PersistentProposalList(os.path.join(cache_root, "{}_bg.json".format(dataset_names[0]))) if not (fg_proposals_list.load() and bg_proposals_list.load()): for dataset_name in dataset_names: data_loader = build_detection_test_loader(cfg, dataset_name, mapper=ProposalMapper(cfg, False)) generator = ProposalGenerator(num_classes=num_classes) fg_proposals_list_data, bg_proposals_list_data = inference_on_dataset(model, data_loader, generator) fg_proposals_list.extend(fg_proposals_list_data) bg_proposals_list.extend(bg_proposals_list_data) fg_proposals_list.flush() bg_proposals_list.flush() return fg_proposals_list, bg_proposals_list def generate_category_labels(prop, category_adaptor, cache_filename): """Generate category labels for each proposals in `prop` and save them to the disk""" prop_w_category = PersistentProposalList(cache_filename) if not prop_w_category.load(): for p in prop: prop_w_category.append(p) data_loader_test = category_adaptor.prepare_test_data(flatten(prop_w_category)) predictions = category_adaptor.predict(data_loader_test) for p in prop_w_category: p.pred_classes = np.array([predictions.popleft() for _ in range(len(p))]) prop_w_category.flush() return prop_w_category def generate_bounding_box_labels(prop, bbox_adaptor, class_names, cache_filename): """Generate bounding box labels for each proposals in `prop` and save them to the disk""" prop_w_bbox = PersistentProposalList(cache_filename) if not prop_w_bbox.load(): # remove (predicted) background proposals for p in prop: keep_indices = (0 <= p.pred_classes) & (p.pred_classes < len(class_names)) prop_w_bbox.append(p[keep_indices]) data_loader_test = bbox_adaptor.prepare_test_data(flatten(prop_w_bbox)) predictions = bbox_adaptor.predict(data_loader_test) for p in prop_w_bbox: p.pred_boxes = 
np.array([predictions.popleft() for _ in range(len(p))]) prop_w_bbox.flush() return prop_w_bbox def train(model, logger, cfg, args, args_cls, args_box): model.train() distributed = comm.get_world_size() > 1 if distributed: model_without_parallel = model.module else: model_without_parallel = model # define optimizer and lr scheduler params = [] for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR): params.extend( get_default_optimizer_params( module, base_lr=lr, weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, ) ) optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)( params, lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV, weight_decay=cfg.SOLVER.WEIGHT_DECAY, ) scheduler = utils.build_lr_scheduler(cfg, optimizer) # resume from the last checkpoint checkpointer = DetectionCheckpointer( model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler ) checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume) start_iter = 0 max_iter = cfg.SOLVER.MAX_ITER periodic_checkpointer = PeriodicCheckpointer( checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter ) writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else [] # generate proposals from detector classes = MetadataCatalog.get(args.targets[0]).thing_classes cache_proposal_root = os.path.join(cfg.OUTPUT_DIR, "cache", "proposal") prop_t_fg, prop_t_bg = generate_proposals(model, len(classes), args.targets, cache_proposal_root, cfg) prop_s_fg, prop_s_bg = generate_proposals(model, len(classes), args.sources, cache_proposal_root, cfg) model = model.to(torch.device('cpu')) # train the category adaptor category_adaptor = category_adaptation.CategoryAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, "cls"), args_cls) if not category_adaptor.load_checkpoint(): data_loader_source = category_adaptor.prepare_training_data(prop_s_fg + 
prop_s_bg, True) data_loader_target = category_adaptor.prepare_training_data(prop_t_fg + prop_t_bg, False) data_loader_validation = category_adaptor.prepare_validation_data(prop_t_fg + prop_t_bg) category_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation) # generate category labels for each proposals cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, "cache", "feedback") prop_t_fg = generate_category_labels( prop_t_fg, category_adaptor, os.path.join(cache_feedback_root, "{}_fg.json".format(args.targets[0])) ) prop_t_bg = generate_category_labels( prop_t_bg, category_adaptor, os.path.join(cache_feedback_root, "{}_bg.json".format(args.targets[0])) ) category_adaptor.model.to(torch.device("cpu")) if args.bbox_refine: # train the bbox adaptor bbox_adaptor = bbox_adaptation.BoundingBoxAdaptor(classes, os.path.join(cfg.OUTPUT_DIR, "bbox"), args_box) if not bbox_adaptor.load_checkpoint(): data_loader_source = bbox_adaptor.prepare_training_data(prop_s_fg, True) data_loader_target = bbox_adaptor.prepare_training_data(prop_t_fg, False) data_loader_validation = bbox_adaptor.prepare_validation_data(prop_t_fg) bbox_adaptor.validate_baseline(data_loader_validation) bbox_adaptor.fit(data_loader_source, data_loader_target, data_loader_validation) # generate bounding box labels for each proposals cache_feedback_root = os.path.join(cfg.OUTPUT_DIR, "cache", "feedback_bbox") prop_t_fg_refined = generate_bounding_box_labels( prop_t_fg, bbox_adaptor, classes, os.path.join(cache_feedback_root, "{}_fg.json".format(args.targets[0])) ) prop_t_bg_refined = generate_bounding_box_labels( prop_t_bg, bbox_adaptor, classes, os.path.join(cache_feedback_root, "{}_bg.json".format(args.targets[0])) ) prop_t_fg += prop_t_fg_refined prop_t_bg += prop_t_bg_refined bbox_adaptor.model.to(torch.device("cpu")) if args.reduce_proposals: # remove proposals prop_t_bg_new = [] for p in prop_t_bg: keep_indices = p.pred_classes == len(classes) prop_t_bg_new.append(p[keep_indices]) 
prop_t_bg = prop_t_bg_new prop_t_fg_new = [] for p in prop_t_fg: prop_t_fg_new.append(p[:20]) prop_t_fg = prop_t_fg_new model = model.to(torch.device(cfg.MODEL.DEVICE)) # Data loading code train_source_dataset = get_detection_dataset_dicts(args.sources) train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg) train_target_dataset = get_detection_dataset_dicts(args.targets, proposals_list=prop_t_fg+prop_t_bg) mapper = DatasetMapper(cfg, precomputed_proposal_topk=1000, augmentations=utils.build_augmentation(cfg, True)) train_target_loader = build_detection_train_loader(dataset=train_target_dataset, cfg=cfg, mapper=mapper, total_batch_size=cfg.SOLVER.IMS_PER_BATCH) # training the object detector logger.info("Starting training from iteration {}".format(start_iter)) with EventStorage(start_iter) as storage: for data_s, data_t, iteration in zip(train_source_loader, train_target_loader, range(start_iter, max_iter)): storage.iter = iteration optimizer.zero_grad() # compute losses and gradient on source domain loss_dict_s = model(data_s) losses_s = sum(loss_dict_s.values()) assert torch.isfinite(losses_s).all(), loss_dict_s loss_dict_reduced_s = {"{}_s".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()} losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values()) losses_s.backward() # compute losses and gradient on target domain loss_dict_t = model(data_t, labeled=False) losses_t = sum(loss_dict_t.values()) assert torch.isfinite(losses_t).all() loss_dict_reduced_t = {"{}_t".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_t).items()} (losses_t * args.trade_off).backward() if comm.is_main_process(): storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s, **loss_dict_reduced_t) # do SGD step optimizer.step() storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) scheduler.step() # evaluate on validation set if ( cfg.TEST.EVAL_PERIOD > 0 and (iteration + 1) % 
cfg.TEST.EVAL_PERIOD == 0 and iteration != max_iter - 1 ): utils.validate(model, logger, cfg, args) comm.synchronize() if iteration - start_iter > 5 and ( (iteration + 1) % 20 == 0 or iteration == max_iter - 1 ): for writer in writers: writer.write() periodic_checkpointer.step(iteration) def main(args, args_cls, args_box): logger = logging.getLogger("detectron2") cfg = utils.setup(args) # dataset args.sources = utils.build_dataset(args.sources[::2], args.sources[1::2]) args.targets = utils.build_dataset(args.targets[::2], args.targets[1::2]) args.test = utils.build_dataset(args.test[::2], args.test[1::2]) # create model model = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune) model.to(torch.device(cfg.MODEL.DEVICE)) logger.info("Model:\n{}".format(model)) if args.eval_only: DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume ) return utils.validate(model, logger, cfg, args) distributed = comm.get_world_size() > 1 if distributed: model = DistributedDataParallel( model, device_ids=[comm.get_local_rank()], broadcast_buffers=False ) train(model, logger, cfg, args, args_cls, args_box) # evaluate on validation set return utils.validate(model, logger, cfg, args) if __name__ == "__main__": args_cls, argv = category_adaptation.CategoryAdaptor.get_parser().parse_known_args() print("Category Adaptation Args:") pprint.pprint(args_cls) args_box, argv = bbox_adaptation.BoundingBoxAdaptor.get_parser().parse_known_args(args=argv) print("Bounding Box Adaptation Args:") pprint.pprint(args_box) parser = argparse.ArgumentParser(add_help=True) # dataset parameters parser.add_argument('-s', '--sources', nargs='+', help='source domain(s)') parser.add_argument('-t', '--targets', nargs='+', help='target domain(s)') parser.add_argument('--test', nargs='+', help='test domain(s)') # model parameters parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller learning rate for backbone') 
parser.add_argument( "--resume", action="store_true", help="Whether to attempt to resume from the checkpoint directory. " "See documentation of `DefaultTrainer.resume_or_load()` for what it means.", ) parser.add_argument('--trade-off', default=1., type=float, help='trade-off hyper-parameter for losses on target domain') parser.add_argument('--bbox-refine', action='store_true', help='whether perform bounding box refinement') parser.add_argument('--reduce-proposals', action='store_true', help='whether remove some low-quality proposals.' 'Helpful for RetinaNet') # training parameters parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*") parser.add_argument("--num-machines", type=int, default=1, help="total number of machines") parser.add_argument("--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)") # PyTorch still may leave orphan processes in multi-gpu training. # Therefore we use a deterministic way to obtain port, # so that users are aware of orphan processes by seeing the port occupied. port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 parser.add_argument( "--dist-url", default="tcp://127.0.0.1:{}".format(port), help="initialization URL for pytorch distributed backend. See " "https://pytorch.org/docs/stable/distributed.html for details.", ) parser.add_argument( "opts", help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. 
" "See config references at " "https://detectron2.readthedocs.io/modules/config.html#config-references", default=None, nargs=argparse.REMAINDER, ) args, argv = parser.parse_known_args(argv) print("Detection Args:") pprint.pprint(args) launch( main, args.num_gpus, num_machines=args.num_machines, machine_rank=args.machine_rank, dist_url=args.dist_url, args=(args, args_cls, args_box), ) ================================================ FILE: examples/domain_adaptation/object_detection/d_adapt/d_adapt.sh ================================================ # ResNet101 Based Faster RCNN: Faster RCNN: VOC->Clipart # 44.8 pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \ -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \ --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 47.9 pretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \ -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \ --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 49.0 pretrained_models=logs/faster_rcnn_R_101_C4/voc2clipart/phase2/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.2 \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \ -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \ --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2clipart/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0 # ResNet101 Based 
Faster RCNN: Faster RCNN: VOC->WaterColor # 54.1 pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \ -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 57.5 pretrained_models=logs/faster_rcnn_R_101_C4/voc2watercolor/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \ -t WaterColor ../datasets/watercolor --test WaterColorTest ../datasets/watercolor --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2watercolor/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0 # ResNet101 Based Faster RCNN: Faster RCNN: VOC->Comic # 39.7 pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \ -t Comic ../datasets/comic --test ComicTest ../datasets/comic --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase1 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 41.0 pretrained_models=logs/faster_rcnn_R_101_C4/voc2comic/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --confidence-ratio-c 0.1 \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007Partial ../datasets/VOC2007 VOC2012Partial ../datasets/VOC2012 \ -t Comic ../datasets/comic --test ComicTest ../datasets/comic --finetune --bbox-refine \ 
OUTPUT_DIR logs/faster_rcnn_R_101_C4/voc2comic/phase2 MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS ${pretrained_models} SEED 0 # ResNet101 Based Faster RCNN: Cityscapes -> Foggy Cityscapes # 40.1 pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \ --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 42.4 pretrained_models=logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \ --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0 # VGG Based Faster RCNN: Cityscapes -> Foggy Cityscapes # 33.3 pretrained_models=../logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \ --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 37.0 
pretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.1 \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \ --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 38.9 pretrained_models=logs/faster_rcnn_vgg_16/cityscapes2foggy/phase2/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 4 --max-train-c 20 --ignored-scores-c 0.05 0.5 --confidence-ratio-c 0.2 \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Cityscapes ../datasets/cityscapes_in_voc -t FoggyCityscapes ../datasets/foggy_cityscapes_in_voc/ \ --test FoggyCityscapesTest ../datasets/foggy_cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_vgg_16/cityscapes2foggy/phase3 MODEL.WEIGHTS ${pretrained_models} SEED 0 # ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car # 51.9 pretrained_models=../logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --workers-c 8 --ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/ \ --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_R_101_C4/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # VGG Based Faster RCNN: Sim10k -> Cityscapes Car # 49.3 pretrained_models=../logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car/model_final.pth CUDA_VISIBLE_DEVICES=0 python 
d_adapt.py --workers-c 8 --ignored-scores-c 0.05 0.5 --bottleneck-dim-c 256 --bottleneck-dim-b 256 \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Sim10kCar ../datasets/sim10k -t CityscapesCar ../datasets/cityscapes_in_voc/ \ --test CityscapesCarTest ../datasets/cityscapes_in_voc/ --finetune --trade-off 0.5 --bbox-refine \ OUTPUT_DIR logs/faster_rcnn_vgg_16/sim10k2cityscapes_car/phase1 MODEL.ROI_HEADS.NUM_CLASSES 1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # RetinaNet: VOC->Clipart # 44.7 pretrained_models=../logs/source_only/retinanet_R_101_FPN/voc2clipart/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --remove-bg \ --config-file config/retinanet_R_101_FPN_voc.yaml \ -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \ -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \ --finetune --bbox-refine \ OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase1 MODEL.WEIGHTS ${pretrained_models} SEED 0 # 46.3 pretrained_models=logs/retinanet_R_101_FPN/voc2clipart/phase1/model_final.pth CUDA_VISIBLE_DEVICES=0 python d_adapt.py --remove-bg --confidence-ratio 0.1 \ --config-file config/retinanet_R_101_FPN_voc.yaml \ -s VOC2007 ../datasets/VOC2007 VOC2012 ../datasets/VOC2012 \ -t Clipart ../datasets/clipart --test Clipart ../datasets/clipart \ --finetune --bbox-refine \ OUTPUT_DIR logs/retinanet_R_101_FPN/voc2clipart/phase2 MODEL.WEIGHTS ${pretrained_models} SEED 0 ================================================ FILE: examples/domain_adaptation/object_detection/oracle.sh ================================================ # Faster RCNN: WaterColor CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s WaterColor datasets/watercolor -t WaterColor datasets/watercolor \ --test WaterColorTest datasets/watercolor --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/watercolor MODEL.ROI_HEADS.NUM_CLASSES 6 # Faster RCNN: Comic CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file 
config/faster_rcnn_R_101_C4_voc.yaml \ -s Comic datasets/comic -t Comic datasets/comic \ --test ComicTest datasets/comic --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/comic MODEL.ROI_HEADS.NUM_CLASSES 6 # ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \ --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes2foggy # VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s FoggyCityscapes datasets/foggy_cityscapes_in_voc -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \ --test FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes2foggy # ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_R_101_C4/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 # VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s CityscapesCar datasets/cityscapes_in_voc/ -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/oracle/faster_rcnn_vgg_16/cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 ================================================ FILE: examples/domain_adaptation/object_detection/prepare_cityscapes_to_voc.py 
================================================ from pascal_voc_writer import Writer import matplotlib.pyplot as plt import numpy as np import os import json import glob import time from shutil import move, copy import tqdm classes = {'bicycle': 'bicycle', 'bus': 'bus', 'car': 'car', 'motorcycle': 'motorcycle', 'person': 'person', 'rider': 'rider', 'train': 'train', 'truck': 'truck'} classes_keys = list(classes.keys()) def make_dir(path): if not os.path.isdir(path): os.makedirs(path) #---------------------------------------------------------------------------------------------------------------- #convert polygon to bounding box #code from: #https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points #---------------------------------------------------------------------------------------------------------------- def polygon_to_bbox(polygon): x_coordinates, y_coordinates = zip(*polygon) return [min(x_coordinates), min(y_coordinates), max(x_coordinates), max(y_coordinates)] # -------------------------------------------- # read a json file and convert to voc format # -------------------------------------------- def read_json(file): # if no relevant objects found in the image, # don't save the xml for the image relevant_file = False data = [] with open(file, 'r') as f: file_data = json.load(f) for object in file_data['objects']: label, polygon = object['label'], object['polygon'] # process only if label found in voc if label in classes_keys: polygon = np.array([x for x in polygon]) bbox = polygon_to_bbox(polygon) data.append([classes[label]] + bbox) # if relevant objects found in image, set the flag to True if data: relevant_file = True return data, relevant_file #--------------------------- #function to save xml file #--------------------------- def save_xml(img_path, img_shape, data, save_path): writer = Writer(img_path,img_shape[1], img_shape[0]) for element in data: 
writer.addObject(element[0],element[1],element[2],element[3],element[4]) writer.save(save_path) def prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir): cityscapes_dir_gt = os.path.join(cityscapes_dir, 'gtFine') # ------------------------------------------ # reading json files from each subdirectory # ------------------------------------------ valid_files = [] trainval_files = [] test_files = [] # make Annotations target directory if already doesn't exist ann_dir = os.path.join(save_path, 'Annotations') make_dir(ann_dir) start = time.time() for category in os.listdir(cityscapes_dir_gt): # # no GT for test data # if category == 'test': # continue for city in tqdm.tqdm(os.listdir(os.path.join(cityscapes_dir_gt, category))): # read files files = glob.glob(os.path.join(cityscapes_dir, 'gtFine', category, city) + '/*.json') # process json files for file in files: data, relevant_file = read_json(file) if relevant_file: base_filename = os.path.basename(file)[:-21] xml_filepath = os.path.join(ann_dir, base_filename + '{}.xml'.format(suffix)) img_name = base_filename + '{}.png'.format(suffix) img_path = os.path.join(cityscapes_dir, image_dir, category, city, base_filename + '{}.png'.format(suffix)) img_shape = plt.imread(img_path).shape valid_files.append([img_path, img_name]) # make list of trainval and test files for voc format # lists will be stored in txt files trainval_files.append(img_name[:-4]) if category == 'train' else test_files.append(img_name[:-4]) # save xml file save_xml(os.path.join(image_dir, category, city, base_filename + '{}.png'.format(suffix)), img_shape, data, xml_filepath) end = time.time() - start print('Total Time taken: ', end) # ---------------------------- # copy files into target path # ---------------------------- images_savepath = os.path.join(save_path, 'JPEGImages') make_dir(images_savepath) start = time.time() for file in valid_files: copy(file[0], os.path.join(images_savepath, file[1])) print('Total Time taken: ', end) 
# --------------------------------------------- # create text files of trainval and test files # --------------------------------------------- textfiles_savepath = os.path.join(save_path, 'ImageSets', 'Main') make_dir(textfiles_savepath) traival_files_wr = [x + '\n' for x in trainval_files] test_files_wr = [x + '\n' for x in test_files] with open(os.path.join(textfiles_savepath, 'trainval.txt'), 'w') as f: f.writelines(traival_files_wr) with open(os.path.join(textfiles_savepath, 'test.txt'), 'w') as f: f.writelines(test_files_wr) if __name__ == '__main__': cityscapes_dir = 'datasets/cityscapes/' if not os.path.exists(cityscapes_dir): print("Please put cityscapes datasets in: {}".format(cityscapes_dir)) exit(0) save_path = 'datasets/cityscapes_in_voc/' suffix = "_leftImg8bit" image_dir = "leftImg8bit" prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir) save_path = 'datasets/foggy_cityscapes_in_voc/' suffix = "_leftImg8bit_foggy_beta_0.02" image_dir = "leftImg8bit_foggy" prepare_cityscapes_to_voc(cityscapes_dir, save_path, suffix, image_dir) ================================================ FILE: examples/domain_adaptation/object_detection/requirements.txt ================================================ mmcv timm prettytable pascal_voc_writer ================================================ FILE: examples/domain_adaptation/object_detection/source_only.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import logging import os import argparse import sys import torch from torch.nn.parallel import DistributedDataParallel from detectron2.engine import default_writers, launch from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer import detectron2.utils.comm as comm from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping from detectron2.solver import build_lr_scheduler from detectron2.data import ( 
build_detection_train_loader, get_detection_dataset_dicts, ) from detectron2.utils.events import EventStorage sys.path.append('../../..') import tllib.vision.models.object_detection.meta_arch as models import utils def train(model, logger, cfg, args): model.train() distributed = comm.get_world_size() > 1 if distributed: model_without_parallel = model.module else: model_without_parallel = model # define optimizer and lr scheduler params = [] for module, lr in model_without_parallel.get_parameters(cfg.SOLVER.BASE_LR): params.extend( get_default_optimizer_params( module, base_lr=lr, weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, ) ) optimizer = maybe_add_gradient_clipping(cfg, torch.optim.SGD)( params, lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV, weight_decay=cfg.SOLVER.WEIGHT_DECAY, ) scheduler = build_lr_scheduler(cfg, optimizer) # resume from the last checkpoint checkpointer = DetectionCheckpointer( model, cfg.OUTPUT_DIR, optimizer=optimizer, scheduler=scheduler ) start_iter = ( checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume).get("iteration", -1) + 1 ) max_iter = cfg.SOLVER.MAX_ITER periodic_checkpointer = PeriodicCheckpointer( checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_iter=max_iter ) writers = default_writers(cfg.OUTPUT_DIR, max_iter) if comm.is_main_process() else [] # Data loading code train_source_dataset = get_detection_dataset_dicts(args.source) train_source_loader = build_detection_train_loader(dataset=train_source_dataset, cfg=cfg) # start training logger.info("Starting training from iteration {}".format(start_iter)) with EventStorage(start_iter) as storage: for data_s, iteration in zip(train_source_loader, range(start_iter, max_iter)): storage.iter = iteration # compute output _, loss_dict_s = model(data_s) losses_s = sum(loss_dict_s.values()) assert torch.isfinite(losses_s).all(), loss_dict_s 
loss_dict_reduced_s = {"{}_s".format(k): v.item() for k, v in comm.reduce_dict(loss_dict_s).items()} losses_reduced_s = sum(loss for loss in loss_dict_reduced_s.values()) if comm.is_main_process(): storage.put_scalars(total_loss_s=losses_reduced_s, **loss_dict_reduced_s) # compute gradient and do SGD step optimizer.zero_grad() losses_s.backward() optimizer.step() storage.put_scalar("lr", optimizer.param_groups[0]["lr"], smoothing_hint=False) scheduler.step() # evaluate on validation set if ( cfg.TEST.EVAL_PERIOD > 0 and (iteration + 1) % cfg.TEST.EVAL_PERIOD == 0 and iteration != max_iter - 1 ): utils.validate(model, logger, cfg, args) comm.synchronize() if iteration - start_iter > 5 and ( (iteration + 1) % 20 == 0 or iteration == max_iter - 1 ): for writer in writers: writer.write() periodic_checkpointer.step(iteration) def main(args): logger = logging.getLogger("detectron2") cfg = utils.setup(args) # dataset args.source = utils.build_dataset(args.source[::2], args.source[1::2]) args.target = utils.build_dataset(args.target[::2], args.target[1::2]) args.test = utils.build_dataset(args.test[::2], args.test[1::2]) # create model model = models.__dict__[cfg.MODEL.META_ARCHITECTURE](cfg, finetune=args.finetune) model.to(torch.device(cfg.MODEL.DEVICE)) logger.info("Model:\n{}".format(model)) if args.eval_only: DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume ) return utils.validate(model, logger, cfg, args) distributed = comm.get_world_size() > 1 if distributed: model = DistributedDataParallel( model, device_ids=[comm.get_local_rank()], broadcast_buffers=False ) train(model, logger, cfg, args) # evaluate on validation set return utils.validate(model, logger, cfg, args) if __name__ == "__main__": parser = argparse.ArgumentParser() # dataset parameters parser.add_argument('-s', '--source', nargs='+', help='source domain(s)') parser.add_argument('-t', '--target', nargs='+', help='target domain(s)') 
parser.add_argument('--test', nargs='+', help='test domain(s)') # model parameters parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller learning rate for backbone') parser.add_argument( "--resume", action="store_true", help="Whether to attempt to resume from the checkpoint directory. " "See documentation of `DefaultTrainer.resume_or_load()` for what it means.", ) # training parameters parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*") parser.add_argument("--num-machines", type=int, default=1, help="total number of machines") parser.add_argument("--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)") # PyTorch still may leave orphan processes in multi-gpu training. # Therefore we use a deterministic way to obtain port, # so that users are aware of orphan processes by seeing the port occupied. port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14 parser.add_argument( "--dist-url", default="tcp://127.0.0.1:{}".format(port), help="initialization URL for pytorch distributed backend. See " "https://pytorch.org/docs/stable/distributed.html for details.", ) parser.add_argument( "opts", help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. 
" "See config references at " "https://detectron2.readthedocs.io/modules/config.html#config-references", default=None, nargs=argparse.REMAINDER, ) args = parser.parse_args() print("Command Line Args:", args) launch( main, args.num_gpus, num_machines=args.num_machines, machine_rank=args.machine_rank, dist_url=args.dist_url, args=(args,), ) ================================================ FILE: examples/domain_adaptation/object_detection/source_only.sh ================================================ # Faster RCNN: VOC->Clipart CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \ --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2clipart # Faster RCNN: VOC->WaterColor, Comic CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_voc.yaml \ -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \ --test VOC2007PartialTest datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic MODEL.ROI_HEADS.NUM_CLASSES 6 # ResNet101 Based Faster RCNN: Cityscapes->Foggy Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes datasets/foggy_cityscapes_in_voc \ --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/cityscapes2foggy # VGG16 Based Faster RCNN: Cityscapes->Foggy Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Cityscapes datasets/cityscapes_in_voc/ -t FoggyCityscapes 
datasets/foggy_cityscapes_in_voc \ --test CityscapesTest datasets/cityscapes_in_voc/ FoggyCityscapesTest datasets/foggy_cityscapes_in_voc --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/cityscapes2foggy # ResNet101 Based Faster RCNN: Sim10k -> Cityscapes Car CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 # VGG16 Based Faster RCNN: Sim10k -> Cityscapes Car CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_vgg_16_cityscapes.yaml \ -s Sim10kCar datasets/sim10k -t CityscapesCar datasets/cityscapes_in_voc/ \ --test CityscapesCarTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_vgg_16/sim10k2cityscapes_car MODEL.ROI_HEADS.NUM_CLASSES 1 # Faster RCNN: GTA5 -> Cityscapes CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/faster_rcnn_R_101_C4_cityscapes.yaml \ -s GTA5 datasets/synscapes_detection -t Cityscapes datasets/cityscapes_in_voc/ \ --test CityscapesTest datasets/cityscapes_in_voc/ --finetune \ OUTPUT_DIR logs/source_only/faster_rcnn_R_101_C4/gta52cityscapes # RetinaNet: VOC->Clipart CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/retinanet_R_101_FPN_voc.yaml \ -s VOC2007 datasets/VOC2007 VOC2012 datasets/VOC2012 -t Clipart datasets/clipart \ --test VOC2007Test datasets/VOC2007 Clipart datasets/clipart --finetune \ OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2clipart # RetinaNet: VOC->WaterColor, Comic CUDA_VISIBLE_DEVICES=0 python source_only.py \ --config-file config/retinanet_R_101_FPN_voc.yaml \ -s VOC2007Partial datasets/VOC2007 VOC2012Partial datasets/VOC2012 -t WaterColor datasets/watercolor Comic datasets/comic \ --test VOC2007PartialTest 
datasets/VOC2007 WaterColorTest datasets/watercolor ComicTest datasets/comic --finetune \
  OUTPUT_DIR logs/source_only/retinanet_R_101_FPN/voc2watercolor_comic MODEL.RETINANET.NUM_CLASSES 6


================================================
FILE: examples/domain_adaptation/object_detection/utils.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import prettytable
from typing import *
from collections import OrderedDict, defaultdict
import tempfile
import logging
import matplotlib as mpl

import torch
import torch.nn as nn
import torchvision.transforms as T
import detectron2.utils.comm as comm
from detectron2.evaluation import PascalVOCDetectionEvaluator, inference_on_dataset
from detectron2.evaluation.pascal_voc_evaluation import voc_eval
from detectron2.config import get_cfg, CfgNode
from detectron2.engine import default_setup
from detectron2.data import (
    build_detection_test_loader,
)
from detectron2.data.transforms.augmentation import Augmentation
from detectron2.data.transforms import BlendTransform, ColorTransform
from detectron2.solver.lr_scheduler import LRMultiplier, WarmupParamScheduler
from detectron2.utils.visualizer import Visualizer
from detectron2.utils.colormap import random_color
from fvcore.common.param_scheduler import *
import timm

import tllib.vision.datasets.object_detection as datasets
import tllib.vision.models as models


class PascalVOCDetectionPerClassEvaluator(PascalVOCDetectionEvaluator):
    """
    Evaluate Pascal VOC style AP with per-class AP for Pascal VOC dataset.
    It contains a synchronization, therefore has to be called from all ranks.
    Note that the concept of AP can be implemented in different ways and may not produce
    identical results. This class mimics the implementation of the official
    Pascal VOC Matlab API, and should produce similar but not identical results to the
    official API.
    """

    def evaluate(self):
        """
        Returns:
            dict: has a key "bbox", whose value is a dict with "AP", "AP50", "AP75"
            and one AP50 entry per class. Returns None on non-main processes.
        """
        # gather predictions from every rank; only rank 0 computes the metrics
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions  # free the gathered copies early

        self._logger.info(
            "Evaluating {} using {} metric. "
            "Note that results do not use the official Matlab API.".format(
                self._dataset_name, 2007 if self._is_2007 else 2012
            )
        )

        # voc_eval consumes per-class result files, so write them to a temp dir
        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou -> ap per class
            for cls_id, cls_name in enumerate(self._class_names):
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                # evaluate at IoU thresholds 0.50, 0.55, ..., 0.95
                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        ret = OrderedDict()
        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]}
        # per-class AP is reported at IoU 0.50
        for cls_name, ap in zip(self._class_names, aps[50]):
            ret["bbox"][cls_name] = ap
        return ret


def validate(model, logger, cfg, args):
    """Run inference on every dataset in args.test and log per-class AP tables.

    Returns the results dict (unwrapped when there is a single test dataset).
    """
    results = OrderedDict()
    for dataset_name in args.test:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        evaluator = PascalVOCDetectionPerClassEvaluator(dataset_name)
        results_i = inference_on_dataset(model, data_loader, evaluator)
        results[dataset_name] = results_i
        # results_i is None on non-main ranks (see evaluate above)
        if comm.is_main_process():
            logger.info(results_i)
            table = prettytable.PrettyTable(["class", "AP"])
            for class_name, ap in results_i["bbox"].items():
                table.add_row([class_name, ap])
            logger.info(table.get_string())
    if len(results) == 1:
        results = list(results.values())[0]
    return results


def
build_dataset(dataset_categories, dataset_roots): """ Give a sequence of dataset class name and a sequence of dataset root directory, return a sequence of built datasets """ dataset_lists = [] for dataset_category, root in zip(dataset_categories, dataset_roots): dataset_lists.append(datasets.__dict__[dataset_category](root).name) return dataset_lists def rgb2gray(rgb): return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])[:, :, np.newaxis].repeat(3, axis=2).astype(rgb.dtype) class Grayscale(Augmentation): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): super().__init__() self._init(locals()) self._transform = T.Grayscale() def get_transform(self, image): return ColorTransform(lambda x: rgb2gray(x)) def build_augmentation(cfg, is_train): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. Returns: list[Augmentation] """ import detectron2.data.transforms as T if is_train: min_size = cfg.INPUT.MIN_SIZE_TRAIN max_size = cfg.INPUT.MAX_SIZE_TRAIN sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING else: min_size = cfg.INPUT.MIN_SIZE_TEST max_size = cfg.INPUT.MAX_SIZE_TEST sample_style = "choice" augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] if is_train and cfg.INPUT.RANDOM_FLIP != "none": augmentation.append( T.RandomFlip( horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", vertical=cfg.INPUT.RANDOM_FLIP == "vertical", ) ) augmentation.append( T.RandomApply(T.AugmentationList( [ T.RandomContrast(0.6, 1.4), T.RandomBrightness(0.6, 1.4), T.RandomSaturation(0.6, 1.4), T.RandomLighting(0.1) ] ), prob=0.8) ) augmentation.append( T.RandomApply(Grayscale(), prob=0.2) ) return augmentation def setup(args): """ Create configs and perform basic setups. 
""" cfg = get_cfg() cfg.merge_from_file(args.config_file) cfg.merge_from_list(args.opts) cfg.freeze() default_setup( cfg, args ) # if you don't like any of the default setup, write your own setup code return cfg def build_lr_scheduler( cfg: CfgNode, optimizer: torch.optim.Optimizer ) -> torch.optim.lr_scheduler._LRScheduler: """ Build a LR scheduler from config. """ name = cfg.SOLVER.LR_SCHEDULER_NAME if name == "WarmupMultiStepLR": steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER] if len(steps) != len(cfg.SOLVER.STEPS): logger = logging.getLogger(__name__) logger.warning( "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. " "These values will be ignored." ) sched = MultiStepParamScheduler( values=[cfg.SOLVER.GAMMA ** k for k in range(len(steps) + 1)], milestones=steps, num_updates=cfg.SOLVER.MAX_ITER, ) elif name == "WarmupCosineLR": sched = CosineParamScheduler(1, 0) elif name == "ExponentialLR": sched = ExponentialParamScheduler(1, cfg.SOLVER.GAMMA) return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) else: raise ValueError("Unknown LR scheduler: {}".format(name)) sched = WarmupParamScheduler( sched, cfg.SOLVER.WARMUP_FACTOR, cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, cfg.SOLVER.WARMUP_METHOD, ) return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) def get_model_names(): return sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) + timm.list_models() def get_model(model_name, pretrain=True): if model_name in models.__dict__: # load models from common.vision.models backbone = models.__dict__[model_name](pretrained=pretrain) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=pretrain) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') except: backbone.out_features = backbone.head.in_features backbone.head = 
nn.Identity()
    return backbone


class VisualizerWithoutAreaSorting(Visualizer):
    """
    Visualizer in detectron2 draws instances according to their area's order. This
    visualizer removes the sorting code to avoid that boxes with lower confidence
    cover boxes with higher confidence.
    """

    def __init__(self, *args, flip=False, **kwargs):
        # flip: when True, label text is drawn at the top-right box corner instead
        super(VisualizerWithoutAreaSorting, self).__init__(*args, **kwargs)
        self.flip = flip

    def overlay_instances(
            self,
            *,
            boxes=None,
            labels=None,
            masks=None,
            keypoints=None,
            assigned_colors=None,
            alpha=1
    ):
        """
        Draw boxes/masks/keypoints with per-instance labels, in the given order
        (no area sorting).

        Args:
            boxes: Boxes/RotatedBoxes or Nx4 (XYXY_ABS) / Nx5 (rotated) ndarray.
            labels (list[str]): text displayed for each instance.
            masks: PolygonMasks/BitMasks, list of polygons, binary masks, or RLE dicts.
            keypoints (Keypoints or array-like): shape (N, K, 3), last dim (x, y, visibility).
            assigned_colors (list[matplotlib.colors]): one color per instance;
                random colors when None.
            alpha: opacity used for drawn polygons.

        Returns:
            output (VisImage): image object with visualizations.
        """
        # establish num_instances from whichever inputs are present, and check agreement
        num_instances = None
        if boxes is not None:
            boxes = self._convert_boxes(boxes)
            num_instances = len(boxes)
        if masks is not None:
            masks = self._convert_masks(masks)
            if num_instances:
                assert len(masks) == num_instances
            else:
                num_instances = len(masks)
        if keypoints is not None:
            if num_instances:
                assert len(keypoints) == num_instances
            else:
                num_instances = len(keypoints)
            keypoints = self._convert_keypoints(keypoints)
        if labels is not None:
            assert len(labels) == num_instances
        if assigned_colors is None:
            assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)]
        if num_instances == 0:
            return self.output
        # rotated boxes (5 columns) are delegated to the base-class path
        if boxes is not None and boxes.shape[1] == 5:
            return self.overlay_rotated_instances(
                boxes=boxes, labels=labels, assigned_colors=assigned_colors
            )

        # NOTE: unlike the base Visualizer, instances are drawn in input order
        for i in range(num_instances):
            color = assigned_colors[i]
            if boxes is not None:
                self.draw_box(boxes[i], edge_color=color)

            if masks is not None:
                for segment in masks[i].polygons:
                    self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha)

            if labels is not None:
                # first get a box
                if boxes is not None:
                    x0, y0, x1, y1 = boxes[i]
                    text_pos = (x0 - 3, y0)  # if drawing boxes, put text on the box corner.
                    horiz_align = "left"
                elif masks is not None:
                    # skip small mask without polygon
                    if len(masks[i].polygons) == 0:
                        continue
                    x0, y0, x1, y1 = masks[i].bbox()
                    # draw text in the center (defined by median) when box is not drawn
                    # median is less sensitive to outliers.
                    text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1]
                    horiz_align = "center"
                else:
                    continue  # drawing the box confidence for keypoints isn't very useful.
                # for small objects, draw text at the side to avoid occlusion
                instance_area = (y1 - y0) * (x1 - x0)
                if (
                        instance_area < 1000 * self.output.scale
                        or y1 - y0 < 40 * self.output.scale
                ):
                    if y1 >= self.output.height - 5:
                        text_pos = (x1, y0)
                    else:
                        text_pos = (x0, y1)

                # scale font with instance height, clipped to a readable range
                height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width)
                lighter_color = self._change_color_brightness(color, brightness_factor=0.7)
                font_size = (
                        np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2)
                        * 1
                        * self._default_font_size
                )
                if self.flip:
                    text_pos = (x1 - 3, y0 - 30)  # if drawing boxes, put text on the box corner.
                self.draw_text(
                    labels[i],
                    text_pos,
                    color=lighter_color,
                    horizontal_alignment=horiz_align,
                    font_size=font_size,
                )

        # draw keypoints
        if keypoints is not None:
            for keypoints_per_instance in keypoints:
                self.draw_and_connect_keypoints(keypoints_per_instance)
        return self.output

    def draw_box(self, box_coord, alpha=1, edge_color="g", line_style="-"):
        # Same as the base Visualizer but with a thicker minimum line width (3).
        x0, y0, x1, y1 = box_coord
        width = x1 - x0
        height = y1 - y0
        linewidth = max(self._default_font_size / 4, 3)
        self.output.ax.add_patch(
            mpl.patches.Rectangle(
                (x0, y0),
                width,
                height,
                fill=False,
                edgecolor=edge_color,
                linewidth=linewidth * self.output.scale,
                alpha=alpha,
                linestyle=line_style,
            )
        )
        return self.output


================================================
FILE: examples/domain_adaptation/object_detection/visualize.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import os
import argparse
import sys

import torch
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import (
    build_detection_test_loader, MetadataCatalog
)
from detectron2.data import detection_utils
from detectron2.engine import default_setup, launch
from detectron2.utils.visualizer import ColorMode

sys.path.append('../../..')
import tllib.vision.models.object_detection.meta_arch as models
import utils


def visualize(cfg, args, model):
    for
dataset_name in args.test:
        data_loader = build_detection_test_loader(cfg, dataset_name)
        # create folder (one sub-directory per test dataset)
        dirname = os.path.join(args.save_path, dataset_name)
        os.makedirs(dirname, exist_ok=True)
        metadata = MetadataCatalog.get(dataset_name)
        n_current = 0  # number of images visualized so far for this dataset
        # switch to eval mode
        model.eval()
        with torch.no_grad():
            for batch in data_loader:
                if n_current >= args.n_visualizations:
                    break
                batch_predictions = model(batch)
                for per_image, predictions in zip(batch, batch_predictions):
                    instances = predictions["instances"].to(torch.device("cpu"))
                    # only visualize boxes with highest confidence
                    instances = instances[0: args.n_bboxes]
                    # only visualize boxes with confidence exceeding the threshold
                    instances = instances[instances.scores > args.threshold]
                    # visualize in reverse order of confidence so high-confidence
                    # boxes are drawn last (on top)
                    index = [i for i in range(len(instances))]
                    index.reverse()
                    instances = instances[index]
                    img = per_image["image"].permute(1, 2, 0).cpu().detach().numpy()
                    img = detection_utils.convert_image_to_rgb(img, cfg.INPUT.FORMAT)
                    # scale pred_box to original resolution
                    ori_height, ori_width, _ = img.shape
                    height, width = instances.image_size
                    ratio = ori_width / width
                    for i in range(len(instances.pred_boxes)):
                        instances.pred_boxes[i].scale(ratio, ratio)
                    # save original image
                    visualizer = utils.VisualizerWithoutAreaSorting(img, metadata=metadata,
                                                                    instance_mode=ColorMode.IMAGE)
                    output = visualizer.draw_instance_predictions(predictions=instances)
                    filepath = str(n_current) + ".png"
                    filepath = os.path.join(dirname, filepath)
                    output.save(filepath)
                    n_current += 1
                    if n_current >= args.n_visualizations:
                        break


def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(
        cfg, args
    )  # if you don't like any of the default setup, write your own setup code
    return cfg


def main(args):
    # Build the model from the config, load trained weights, and visualize.
    cfg = setup(args)
    meta_arch = cfg.MODEL.META_ARCHITECTURE
    model = models.__dict__[meta_arch](cfg, finetune=True)
    model.to(torch.device(cfg.MODEL.DEVICE))
    DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
        cfg.MODEL.WEIGHTS, resume=False
    )
    visualize(cfg, args, model)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
    parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
    parser.add_argument(
        "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
    )
    # PyTorch still may leave orphan processes in multi-gpu training.
    # Therefore we use a deterministic way to obtain port,
    # so that users are aware of orphan processes by seeing the port occupied.
    port = 2 ** 15 + 2 ** 14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2 ** 14
    parser.add_argument(
        "--dist-url",
        default="tcp://127.0.0.1:{}".format(port),
        help="initialization URL for pytorch distributed backend. See "
             "https://pytorch.org/docs/stable/distributed.html for details.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options by adding 'KEY VALUE' pairs at the end of the command. "
             "See config references at "
             "https://detectron2.readthedocs.io/modules/config.html#config-references",
        default=None,
        nargs=argparse.REMAINDER,
    )
    parser.add_argument('--test', nargs='+', help='test domain(s)')
    parser.add_argument('--save-path', type=str, help='where to save visualization results ')
    parser.add_argument('--n-visualizations', default=100, type=int,
                        help='maximum number of images to visualize (default: 100)')
    parser.add_argument('--threshold', default=0.5, type=float,
                        help='confidence threshold of bounding boxes to visualize (default: 0.5)')
    parser.add_argument('--n-bboxes', default=10, type=int,
                        help='maximum number of bounding boxes to visualize in a single image (default: 10)')
    args = parser.parse_args()
    print("Command Line Args:", args)
    # --test takes alternating (name, root) pairs
    args.test = utils.build_dataset(args.test[::2], args.test[1::2])
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )


================================================
FILE: examples/domain_adaptation/object_detection/visualize.sh
================================================
# Source Only Faster RCNN: VOC->Clipart
CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  --test Clipart datasets/clipart --save-path visualizations/source_only/voc2clipart \
  MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2clipart/model_final.pth

# Source Only Faster RCNN: VOC->WaterColor, Comic
CUDA_VISIBLE_DEVICES=0 python visualize.py --config-file config/faster_rcnn_R_101_C4_voc.yaml \
  --test WaterColorTest datasets/watercolor ComicTest datasets/comic --save-path visualizations/source_only/voc2comic_watercolor \
  MODEL.ROI_HEADS.NUM_CLASSES 6 MODEL.WEIGHTS logs/source_only/faster_rcnn_R_101_C4/voc2watercolor_comic/model_final.pth


================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/README.md
================================================
# Open-set
Domain Adaptation for Image Classification

## Installation

It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.

Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models.

```
pip install timm
```

## Dataset

The following datasets can be downloaded automatically:

- [Office31](https://www.cc.gatech.edu/~judy/domainadapt/)
- [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html)
- [VisDA2017](http://ai.bu.edu/visda-2017/)

## Supported Methods

Supported methods include:

- [Open Set Domain Adaptation (OSBP)](https://arxiv.org/abs/1804.10427)

## Experiment and Results

The shell files give the scripts to reproduce the benchmark with specified hyper-parameters.
For example, if you want to train DANN on Office31, use the following script

```shell script
# Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50.
# Assume you have put the datasets under the path `data/office31`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
```

**Notations**

- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with data from the source domain.

We report ``HOS`` used in [ROS (ECCV 2020)](http://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610409.pdf) to better measure the abilities of different open set domain adaptation algorithms.

We report the best ``HOS`` in all epochs. DANN (baseline model) will degrade performance as training progresses, thus the final ``HOS`` will be much lower than reported. In contrast, OSBP will improve performance stably.
### Office-31 H-Score on ResNet-50 | Methods | Avg | A → W | D → W | W → D | A → D | D → A | W → A | |-------------|------|-------|-------|-------|-------|-------|-------| | ERM | 75.9 | 67.7 | 85.7 | 91.4 | 72.1 | 68.4 | 67.8 | | DANN | 80.4 | 81.4 | 89.1 | 92.0 | 82.5 | 66.7 | 70.4 | | OSBP | 87.8 | 90.7 | 96.4 | 97.5 | 88.7 | 77.0 | 76.7 | ### Office-Home HOS on ResNet-50 | Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr | |-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------| | Source Only | / | 59.8 | 55.2 | 65.2 | 71.4 | 52.8 | 59.6 | 65.2 | 55.8 | 44.8 | 68.0 | 63.8 | 49.4 | 68.0 | | DANN | / | 64.8 | 55.2 | 65.2 | 71.4 | 52.8 | 59.6 | 65.2 | 55.8 | 44.8 | 68.0 | 63.8 | 49.4 | 68.0 | | OSBP | 64.7 | 68.6 | 62.0 | 70.8 | 76.5 | 66.4 | 68.8 | 73.8 | 65.8 | 57.1 | 75.4 | 70.6 | 60.6 | 75.9 | ### VisDA-2017 performance on ResNet-50 | Methods | HOS | OS | OS* | UNK | bcycl | bus | car | mcycl | train | truck | |-------------|------|------|------|------|-------|------|------|-------|-------|-------| | Source Only | 42.6 | 37.6 | 34.7 | 55.1 | 42.6 | 6.4 | 30.5 | 67.1 | 84.0 | 0.2 | | DANN | 57.8 | 50.4 | 45.6 | 78.9 | 20.1 | 71.4 | 29.5 | 74.4 | 67.8 | 10.4 | | OSBP | 75.4 | 67.3 | 62.9 | 94.3 | 63.7 | 75.9 | 49.6 | 74.4 | 86.2 | 27.3 | ## Citation If you use these methods in your research, please consider citing. 
```
@InProceedings{OSBP,
    author = {Saito, Kuniaki and Yamamoto, Shohei and Ushiku, Yoshitaka and Harada, Tatsuya},
    title = {Open Set Domain Adaptation by Backpropagation},
    booktitle = {ECCV},
    year = {2018}
}
```


================================================
FILE: examples/domain_adaptation/openset_domain_adaptation/dann.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.classifier import Classifier
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    # Entry point: builds data/model, then trains, analyzes or tests
    # depending on args.phase.
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # infinite iterators so an epoch is defined by iters_per_epoch, not dataset size
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr,
                    momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri).to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # start training
    best_h_score = 0.
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one DANN training epoch.

    Each iteration draws one source and one target batch, forwards them
    jointly through the classifier, and minimizes the source classification
    loss plus the (trade-off weighted) domain-adversarial transfer loss.
    Target labels are consumed only to report target accuracy; they never
    contribute to the loss.
    """
    # running statistics shown by the progress meter
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # enable training-mode behavior (dropout, batch-norm updates, GRL schedule)
    model.train()
    domain_adv.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        data_time.update(time.time() - tick)

        # single forward pass over the concatenated source/target batch,
        # then split logits and features back into their halves
        y, f = model(torch.cat((x_s, x_t), dim=0))
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        # accuracy of the domain discriminator on this batch (side product of the loss)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        n = x_s.size(0)
        losses.update(loss.item(), n)
        cls_accs.update(cls_acc.item(), n)
        tgt_accs.update(tgt_acc.item(), n)
        domain_accs.update(domain_acc.item(), n)

        # SGD step followed by the per-iteration lr schedule
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate open-set classification on ``val_loader`` and return the H-score.

    The last class index is treated as "unknown": after softmax, the unknown
    slot is overwritten with ``args.threshold``, so a sample is predicted as
    unknown exactly when no known class reaches the threshold.

    Returns:
        H-score (%): harmonic mean of mean known-class accuracy and
        unknown-class accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            softmax_output = F.softmax(output, dim=1)
            # confidence thresholding: predict "unknown" (last class) whenever
            # every known-class confidence falls below args.threshold
            softmax_output[:, -1] = args.threshold

            # accumulate the confusion matrix
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    acc_global, accs, iu = confmat.compute()
    all_acc = torch.mean(accs).item() * 100
    known = torch.mean(accs[:-1]).item() * 100
    unknown = accs[-1].item() * 100
    # Harmonic mean of known/unknown accuracy. Guard the degenerate case where
    # both are zero (e.g. an untrained model) to avoid ZeroDivisionError.
    h_score = 2 * known * unknown / (known + unknown) if (known + unknown) > 0 else 0.
    if args.per_class_eval:
        print(confmat.format(classes))
    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))

    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DANN for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    # NOTE: help text used to say "(default: 0.5)", which contradicted the actual default of 0.8
    parser.add_argument('--threshold', default=0.8, type=float,
                        help='When class confidence is less than the given threshold, '
                             'model will output "unknown" (default: 0.8)')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    # NOTE: help text used to say "(default: 4)", which contradicted the actual default of 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
logs/dann/Office31_W2A # Office-Home CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/dann/OfficeHome_Rw2Pr # VisDA-2017 CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \ --epochs 30 --seed 0 --train-resizing cen.crop --per-class-eval --log 
def main(args: argparse.Namespace):
    """Entry point for source-only (ERM) open-set training.

    Depending on ``args.phase`` this trains on the source domain and
    model-selects by validation H-score, runs a test-set evaluation, or
    performs feature analysis (t-SNE + A-distance) from the best checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # build data pipelines (only the source domain is used for training)
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)

    # build model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    if args.no_pool:
        pool_layer = nn.Identity()
    else:
        pool_layer = None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer).to(device)

    # optimizer with inverse-decay lr schedule stepped once per iteration
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # for test/analysis phases, resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # using shuffled val loader
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=True, num_workers=args.workers)
        # extract features from both domains, then visualize and measure discrepancy
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(val_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # training loop: keep the checkpoint with the best validation H-score
    best_h_score = 0.
    for epoch in range(args.epochs):
        train(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args)
        h_score = validate(val_loader, classifier, args)
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)
    print("best_h_score = {:3.1f}".format(best_h_score))

    # final evaluation on the test set with the best checkpoint
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch of plain empirical risk minimization
    (cross-entropy on source-domain batches only)."""
    # progress bookkeeping
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        data_time.update(time.time() - tick)

        # forward pass; features are returned but unused by ERM
        y_s, f_s = model(x_s)
        loss = F.cross_entropy(y_s, labels_s)

        cls_acc = accuracy(y_s, labels_s)[0]
        batch = x_s.size(0)
        losses.update(loss.item(), batch)
        cls_accs.update(cls_acc.item(), batch)

        # SGD step followed by the per-iteration lr schedule
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate open-set classification on ``val_loader`` and return the H-score.

    The last class index is treated as "unknown": after softmax, the unknown
    slot is overwritten with ``args.threshold``, so a sample is predicted as
    unknown exactly when no known class reaches the threshold.

    Returns:
        H-score (%): harmonic mean of mean known-class accuracy and
        unknown-class accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            softmax_output = F.softmax(output, dim=1)
            # confidence thresholding: predict "unknown" (last class) whenever
            # every known-class confidence falls below args.threshold
            softmax_output[:, -1] = args.threshold

            # accumulate the confusion matrix
            confmat.update(target, softmax_output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    acc_global, accs, iu = confmat.compute()
    all_acc = torch.mean(accs).item() * 100
    known = torch.mean(accs[:-1]).item() * 100
    unknown = accs[-1].item() * 100
    # Harmonic mean of known/unknown accuracy. Guard the degenerate case where
    # both are zero (e.g. an untrained model) to avoid ZeroDivisionError.
    h_score = 2 * known * unknown / (known + unknown) if (known + unknown) > 0 else 0.
    if args.per_class_eval:
        print(confmat.format(classes))
    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))

    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Source Only for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    # NOTE: help text used to say "(default: 0.5)", which contradicted the actual default of 0.8
    parser.add_argument('--threshold', default=0.8, type=float,
                        help='When class confidence is less than the given threshold, '
                             'model will output "unknown" (default: 0.8)')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # NOTE: help text used to say "(default: 4)", which contradicted the actual default of 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/erm/OfficeHome_Rw2Pr # VisDA-2017 CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \ --epochs 30 -i 500 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/erm/VisDA2017_S2R ================================================ FILE: examples/domain_adaptation/openset_domain_adaptation/osbp.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.alignment.osbp import ImageClassifier as Classifier, UnknownClassBinaryCrossEntropy from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy, ConfusionMatrix from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, 
def main(args: argparse.Namespace):
    """Entry point for OSBP open-set domain adaptation.

    Depending on ``args.phase`` this trains with the OSBP objective and
    model-selects by validation H-score, runs a test-set evaluation, or
    performs feature analysis (t-SNE + A-distance) from the best checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # build data pipelines for both domains
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # build model and OSBP loss (t=0.5 is the adversarial target for the unknown class)
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    if args.no_pool:
        pool_layer = nn.Identity()
    else:
        pool_layer = None
    classifier = Classifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                            pool_layer=pool_layer).to(device)
    print(classifier)
    unknown_bce = UnknownClassBinaryCrossEntropy(t=0.5)

    # optimizer with inverse-decay lr schedule stepped once per iteration
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # for test/analysis phases, resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # extract features from both domains, then visualize and measure discrepancy
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_loader, classifier, args)
        print(acc1)
        return

    # training loop: keep the checkpoint with the best validation H-score
    best_h_score = 0.
    for epoch in range(args.epochs):
        train(train_source_iter, train_target_iter, classifier, unknown_bce,
              optimizer, lr_scheduler, epoch, args)
        h_score = validate(val_loader, classifier, args)
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if h_score > best_h_score:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_h_score = max(h_score, best_h_score)
    print("best_h_score = {:3.1f}".format(best_h_score))

    # final evaluation on the test set with the best checkpoint
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    h_score = validate(test_loader, classifier, args)
    print("test_h_score = {:3.1f}".format(h_score))

    logger.close()
def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: Classifier, unknown_bce: UnknownClassBinaryCrossEntropy,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one OSBP training epoch.

    Source batches are trained with plain cross-entropy; target batches pass
    through the gradient-reversed head and are pushed toward/away from the
    unknown class by the binary cross-entropy term. Target labels are used
    only to report target accuracy, never for the loss.
    """
    # running statistics shown by the progress meter
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    trans_losses = AverageMeter('Trans Loss', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, trans_losses, cls_accs, tgt_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tick = time.time()
    for step in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s, x_t = x_s.to(device), x_t.to(device)
        labels_s, labels_t = labels_s.to(device), labels_t.to(device)

        data_time.update(time.time() - tick)

        # source forward without gradient reversal; target forward with it
        y_s, _ = model(x_s, grad_reverse=False)
        y_t, _ = model(x_t, grad_reverse=True)

        cls_loss = F.cross_entropy(y_s, labels_s)
        trans_loss = unknown_bce(y_t)
        loss = cls_loss + trans_loss

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        trans_losses.update(trans_loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_t.size(0))

        # SGD step followed by the per-iteration lr schedule
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        batch_time.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
def validate(val_loader: DataLoader, model: Classifier, args: argparse.Namespace) -> float:
    """Evaluate OSBP open-set classification on ``val_loader`` and return the H-score.

    Unlike the threshold-based baselines, the OSBP classifier emits an explicit
    "unknown" logit (last class index), so predictions come directly from
    ``argmax`` without any confidence thresholding.

    Returns:
        H-score (%): harmonic mean of mean known-class accuracy and
        unknown-class accuracy.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    classes = val_loader.dataset.classes
    confmat = ConfusionMatrix(len(classes))
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.to(device)

            # compute output; the last logit corresponds to the unknown class
            output = model(images)

            # accumulate the confusion matrix
            confmat.update(target, output.argmax(1))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    acc_global, accs, iu = confmat.compute()
    all_acc = torch.mean(accs).item() * 100
    known = torch.mean(accs[:-1]).item() * 100
    unknown = accs[-1].item() * 100
    # Harmonic mean of known/unknown accuracy. Guard the degenerate case where
    # both are zero (e.g. an untrained model) to avoid ZeroDivisionError.
    h_score = 2 * known * unknown / (known + unknown) if (known + unknown) > 0 else 0.
    if args.per_class_eval:
        print(confmat.format(classes))
    print(' * All {all:.3f} Known {known:.3f} Unknown {unknown:.3f} H-score {h_score:.3f}'
          .format(all=all_acc, known=known, unknown=unknown, h_score=h_score))

    return h_score
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='OSBP for Openset Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # NOTE: help text used to say "(default: 4)", which contradicted the actual default of 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='osbp',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python osbp.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/osbp/OfficeHome_Rw2Pr # VisDA-2017 CUDA_VISIBLE_DEVICES=0 python osbp.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \ --epochs 30 -i 1000 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/osbp/VisDA2017_S2R ================================================ FILE: examples/domain_adaptation/openset_domain_adaptation/utils.py ================================================ import sys import timm import torch.nn as nn import torchvision.transforms as T sys.path.append('../../..') import tllib.vision.datasets.openset as datasets from tllib.vision.datasets.openset import default_open_set as open_set import tllib.vision.models as models from tllib.vision.transforms import ResizeImage def 
get_model_names(): return sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) + timm.list_models() def get_model(model_name): if model_name in models.__dict__: # load models from tllib.vision.models backbone = models.__dict__[model_name](pretrained=True) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=True) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') backbone.copy_head = backbone.get_classifier except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() backbone.copy_head = lambda x: x.head return backbone def get_dataset_names(): return sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None): if train_target_transform is None: train_target_transform = train_source_transform # load datasets from tllib.vision.datasets dataset = datasets.__dict__[dataset_name] source_dataset = open_set(dataset, source=True) target_dataset = open_set(dataset, source=False) train_source_dataset = source_dataset(root=root, task=source, download=True, transform=train_source_transform) train_target_dataset = target_dataset(root=root, task=target, download=True, transform=train_target_transform) val_dataset = target_dataset(root=root, task=target, download=True, transform=val_transform) if dataset_name == 'DomainNet': test_dataset = target_dataset(root=root, task=target, split='test', download=True, transform=val_transform) else: test_dataset = val_dataset class_names = train_source_dataset.classes num_classes = len(class_names) return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names def get_train_transform(resizing='default', random_horizontal_flip=True, 
random_color_jitter=False): """ resizing mode: - default: resize the image to 256 and take a random resized crop of size 224; - cen.crop: resize the image to 256 and take the center crop of size 224; - res: resize the image to 224; - res.|crop: resize the image to 256 and take a random crop of size 224; - res.sma|crop: resize the image keeping its aspect ratio such that the smaller side is 256, then take a random crop of size 224; – inc.crop: “inception crop” from (Szegedy et al., 2015); – cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224. """ if resizing == 'default': transform = T.Compose([ ResizeImage(256), T.RandomResizedCrop(224) ]) elif resizing == 'cen.crop': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224) ]) elif resizing == 'ran.crop': transform = T.Compose([ ResizeImage(256), T.RandomCrop(224) ]) elif resizing == 'res.': transform = T.Resize(224) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize((256, 256)), T.RandomCrop(224) ]) elif resizing == "res.sma|crop": transform = T.Compose([ T.Resize(256), T.RandomCrop(224) ]) elif resizing == 'inc.crop': transform = T.RandomResizedCrop(224) elif resizing == 'cif.crop': transform = T.Compose([ T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224), ]) else: raise NotImplementedError(resizing) transforms = [transform] if random_horizontal_flip: transforms.append(T.RandomHorizontalFlip()) if random_color_jitter: transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)) transforms.extend([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return T.Compose(transforms) def get_val_transform(resizing='default'): """ resizing mode: - default: resize the image to 256 and take the center crop of size 224; – res.: resize the image to 224 – res.|crop: resize the image such that the smaller side is of size 256 and then take a central crop of size 224. 
""" if resizing == 'default': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224), ]) elif resizing == 'res.': transform = T.Resize((224, 224)) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize(256), T.CenterCrop(224), ]) else: raise NotImplementedError(resizing) return T.Compose([ transform, T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/README.md ================================================ # Partial Domain Adaptation for Image Classification ## Installation It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results. Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [Office31](https://www.cc.gatech.edu/~judy/domainadapt/) - [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html) - [VisDA2017](http://ai.bu.edu/visda-2017/) ## Supported Methods Supported methods include: - [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818) - [Partial Adversarial Domain Adaptation (PADA)](https://arxiv.org/abs/1808.04205) - [Importance Weighted Adversarial Nets (IWAN)](https://arxiv.org/abs/1803.09210) - [Adaptive Feature Norm (AFN)](https://arxiv.org/pdf/1811.07456v2.pdf) ## Experiment and Results The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to train DANN on Office31, use the following script ```shell script # Train a DANN on Office-31 Amazon -> Webcam task using ResNet 50. 
# Assume you have put the datasets under the path `data/office-31`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 1 --log logs/dann/Office31_A2W
```

**Notations**
- ``Origin`` means the accuracy reported by the original paper.
- ``Avg`` is the accuracy reported by `TLlib`.
- ``ERM`` refers to the model trained with data from the source domain.
- ``Oracle`` refers to the model trained with data from the target domain.

We found that the accuracies of adversarial methods (including DANN) are not stable even after the random seed is fixed, thus we repeat running adversarial methods on *Office-31* and *VisDA-2017* for three times and report their average accuracy.

### Office-31 accuracy on ResNet-50

| Methods | Origin | Avg | A → W | D → W | W → D | A → D | D → A | W → A |
|-------------|--------|------|-------|-------|-------|-------|-------|-------|
| ERM | 75.6 | 90.1 | 78.3 | 98.3 | 99.4 | 87.3 | 88.5 | 88.8 |
| DANN | 43.4 | 82.4 | 60.0 | 94.9 | 98.1 | 71.3 | 84.9 | 85.0 |
| PADA | 92.7 | 93.8 | 86.4 | 100.0 | 100.0 | 87.3 | 93.8 | 95.4 |
| IWAN | 94.7 | 94.8 | 91.2 | 99.7 | 99.4 | 89.8 | 94.2 | 94.3 |
| AFN | / | 93.1 | 87.8 | 95.6 | 99.4 | 87.9 | 93.9 | 94.1 |

### Office-Home accuracy on ResNet-50

| Methods | Origin | Avg | Ar → Cl | Ar → Pr | Ar → Rw | Cl → Ar | Cl → Pr | Cl → Rw | Pr → Ar | Pr → Cl | Pr → Rw | Rw → Ar | Rw → Cl | Rw → Pr |
|-------------|--------|------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|---------|
| ERM | 53.7 | 60.1 | 42.0 | 66.9 | 78.5 | 56.4 | 55.2 | 65.4 | 57.9 | 36.0 | 75.5 | 68.7 | 43.6 | 74.8 |
| DANN | 47.4 | 57.0 | 46.2 | 59.3 | 76.9 | 47.0 | 47.4 | 56.4 | 51.6 | 38.8 | 72.1 | 68.0 | 46.1 | 74.2 |
| PADA | 62.1 | 65.9 | 52.9 | 69.3 | 82.8 | 59.0 | 57.5 | 66.4 | 66.0 | 41.7 | 82.5 | 78.0 | 50.2 | 84.1 |
| IWAN
| 63.6 | 71.3 | 59.2 | 76.6 | 84.0 | 67.8 | 66.7 | 69.2 | 73.3 | 55.0 | 83.9 | 79.0 | 58.3 | 82.2 | | AFN | 71.8 | 72.6 | 59.2 | 76.7 | 82.8 | 72.5 | 74.5 | 76.8 | 72.5 | 56.7 | 80.8 | 77.0 | 60.5 | 81.6 | ### VisDA-2017 accuracy on ResNet-50 | Methods | Origin | Mean | plane | bcycl | bus | car | horse | knife | Avg | |-------------|--------|------|-------|-------|------|------|-------|-------|------| | ERM | 45.3 | 50.9 | 59.2 | 31.3 | 68.7 | 73.2 | 69.3 | 3.4 | 60.0 | | DANN | 51.0 | 55.9 | 88.4 | 34.1 | 72.1 | 50.7 | 61.9 | 27.8 | 57.1 | | PADA | 53.5 | 60.5 | 89.4 | 35.1 | 72.5 | 69.2 | 86.7 | 10.1 | 66.8 | | IWAN | / | 61.5 | 89.2 | 57.0 | 61.5 | 55.2 | 80.1 | 25.7 | 66.8 | | AFN | 67.6 | 61.0 | 79.1 | 62.7 | 73.9 | 49.6 | 79.6 | 21.0 | 64.1 | ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{DANN, author = {Ganin, Yaroslav and Lempitsky, Victor}, Booktitle = {ICML}, Title = {Unsupervised domain adaptation by backpropagation}, Year = {2015} } @InProceedings{PADA, author = {Zhangjie Cao and Lijia Ma and Mingsheng Long and Jianmin Wang}, title = {Partial Adversarial Domain Adaptation}, booktitle = {ECCV}, year = {2018} } @InProceedings{IWAN, author = {Jing Zhang and Zewei Ding and Wanqing Li and Philip Ogunbona}, title = {Importance Weighted Adversarial Nets for Partial Domain Adaptation}, booktitle = {CVPR}, year = {2018} } @InProceedings{AFN, author = {Xu, Ruijia and Li, Guanbin and Yang, Jihan and Lin, Liang}, title = {Larger Norm More Transferable: An Adaptive Feature Norm Approach for Unsupervised Domain Adaptation}, booktitle = {ICCV}, year = {2019} } ``` ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/afn.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import sys import argparse import shutil import os.path as osp import torch import 
torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.normalization.afn import AdaptiveFeatureNorm, ImageClassifier from tllib.modules.entropy import entropy import tllib.vision.models as models from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=False) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = 
DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None backbone = models.__dict__[args.arch](pretrained=True) classifier = ImageClassifier(backbone, train_source_dataset.num_classes, args.num_blocks, bottleneck_dim=args.bottleneck_dim, dropout_p=args.dropout_p, pool_layer=pool_layer).to(device) adaptive_feature_norm = AdaptiveFeatureNorm(args.delta).to(device) # define optimizer # the learning rate is fixed according to origin paper optimizer = SGD(classifier.get_parameters(), args.lr, weight_decay=args.weight_decay) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
for epoch in range(args.epochs): # train for one epoch train(train_source_iter, train_target_iter, classifier, adaptive_feature_norm, optimizer, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, adaptive_feature_norm: AdaptiveFeatureNorm, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':3.1f') data_time = AverageMeter('Data', ':3.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') norm_losses = AverageMeter('Norm Loss', ':3.2f') src_feature_norm = AverageMeter('Source Feature Norm', ':3.2f') tgt_feature_norm = AverageMeter('Target Feature Norm', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') tgt_accs = AverageMeter('Tgt Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, cls_losses, norm_losses, src_feature_norm, tgt_feature_norm, cls_accs, tgt_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter) x_t, labels_t = next(train_target_iter) x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) labels_t = labels_t.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s, f_s = model(x_s) y_t, f_t = model(x_t) # classification loss cls_loss 
= F.cross_entropy(y_s, labels_s) # norm loss norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t) loss = cls_loss + norm_loss * args.trade_off_norm # using entropy minimization if args.trade_off_entropy: y_t = F.softmax(y_t, dim=1) entropy_loss = entropy(y_t, reduction='mean') loss += entropy_loss * args.trade_off_entropy # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # update statistics cls_acc = accuracy(y_s, labels_s)[0] tgt_acc = accuracy(y_t, labels_t)[0] cls_losses.update(cls_loss.item(), x_s.size(0)) norm_losses.update(norm_loss.item(), x_s.size(0)) src_feature_norm.update(f_s.norm(p=2, dim=1).mean().item(), x_s.size(0)) tgt_feature_norm.update(f_t.norm(p=2, dim=1).mean().item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) tgt_accs.update(tgt_acc.item(), x_s.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='AFN for Partial Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain') parser.add_argument('-t', '--target', help='target domain') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('-n', '--num-blocks', 
default=1, type=int, help='Number of basic blocks for classifier') parser.add_argument('--bottleneck-dim', default=1000, type=int, help='Dimension of bottleneck') parser.add_argument('--dropout-p', default=0.5, type=float, help='Dropout probability') # training parameters parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default: 5e-4)', dest='weight_decay') parser.add_argument('--trade-off-norm', default=0.05, type=float, help='the trade-off hyper-parameter for norm loss') parser.add_argument('--trade-off-entropy', default=None, type=float, help='the trade-off hyper-parameter for entropy loss') parser.add_argument('-r', '--delta', default=1, type=float, help='Increment for L2 norm') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='afn', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/afn.sh ================================================ #!/usr/bin/env bash # Office31 CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2W CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t W -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2W CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2D CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s A -t D -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_A2D CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s D -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_D2A CUDA_VISIBLE_DEVICES=0 python afn.py data/office31 -d Office31 -s W -t A -a resnet50 --trade-off-entropy 0.1 --epochs 20 --seed 1 --log logs/afn/Office31_W2A # Office-Home CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Ar 
-t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Cl2Rw CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Ar CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Cl CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Pr2Rw CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Ar CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Cl CUDA_VISIBLE_DEVICES=0 python afn.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/afn/OfficeHome_Rw2Pr # VisDA-2017 CUDA_VISIBLE_DEVICES=0 python afn.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 -r 0.3 -b 36 \ --epochs 30 -i 1000 --seed 0 --per-class-eval --train-resizing cen.crop --log logs/afn/VisDA2017 ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/dann.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import shutil import os.path as osp import torch import torch.nn as nn import 
torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.modules.classifier import Classifier from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import collect_feature, tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=False) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \ utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None if args.data == 'ImageNetCaltech': classifier = Classifier(backbone, num_classes, head=backbone.copy_head(), pool_layer=pool_layer).to(device) else: classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim, pool_layer=pool_layer).to(device) domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function domain_adv = DomainAdversarialLoss(domain_discri).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
# (continuation of main(): epoch loop, checkpointing, and final test evaluation)
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier,
          domain_adv: DomainAdversarialLoss, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of DANN training.

    For each of ``args.iters_per_epoch`` iterations: draw one source and one
    target batch, forward both through ``model`` in a single concatenated pass,
    and minimise source cross-entropy plus ``args.trade_off`` times the
    domain-adversarial loss. ``optimizer`` and ``lr_scheduler`` step every
    iteration; running meters are printed every ``args.print_freq`` steps.
    """
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        # one forward pass over the concatenated source+target batch, then split
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        cls_loss = F.cross_entropy(y_s, labels_s)
        transfer_loss = domain_adv(f_s, f_t)
        domain_acc = domain_adv.domain_discriminator_accuracy
        loss = cls_loss + transfer_loss * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        # target labels are used for monitoring only, never for the loss
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs.update(domain_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DANN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.002, type=float, metavar='LR',
help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay',default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='dann', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/dann.sh ================================================ #!/usr/bin/env bash # Office31 CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2W CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2W CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2D CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_A2D CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_D2A CUDA_VISIBLE_DEVICES=0 python dann.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/dann/Office31_W2A # Office-Home CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Ar2Rw CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Ar CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Pr CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Cl2Rw 
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python dann.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 --seed 0 --log logs/dann/OfficeHome_Rw2Pr

# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python dann.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 5 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/dann/VisDA2017_S2R

# ImageNet-Caltech
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --epochs 5 --seed 0 --log logs/dann/I2C
CUDA_VISIBLE_DEVICES=0 python dann.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --epochs 5 --seed 0 --bottleneck-dim 2048 --log logs/dann/C2I


================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.classifier import Classifier
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Source-only (ERM) baseline: train on labeled source data only, evaluate on target."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    # target training split is unused in ERM, hence the `_` placeholder
    train_source_dataset, _, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    head = backbone.copy_head() if args.data == 'ImageNetCaltech' else None
    classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, head=head).to(device)

    # define optimizer and lr scheduler
    # NOTE: args.wd is the argparse dest derived from the first long option '--wd'
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # using shuffled val loader
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=True,
                                num_workers=args.workers)
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(val_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, model: Classifier, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of plain supervised training (cross-entropy on source batches only)."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_s = x_s.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)

        cls_loss = F.cross_entropy(y_s, labels_s)
        loss = cls_loss

        cls_acc = accuracy(y_s, labels_s)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Source Only for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0003, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    # no explicit dest here, so this value is read as args.wd in main()
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    # NOTE(review): help text says "(default: 4)" but the actual default below is 2
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/erm.sh
================================================
#!/usr/bin/env bash
# Office31
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python erm.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 5 --seed 0 --log logs/erm/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 5 -i 500 --seed 0 --log logs/erm/OfficeHome_Rw2Pr

# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python erm.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 10 -i 500 --seed 0 --per-class-eval --log logs/erm/VisDA2017_S2R

# ImageNet-Caltech
CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --epochs 20 --seed 0 -i 2000 --log logs/erm/I2C
CUDA_VISIBLE_DEVICES=0 python erm.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --epochs 20 --seed 0 -i 2000 --log logs/erm/C2I


================================================
FILE:
examples/domain_adaptation/partial_domain_adaptation/iwan.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.classifier import Classifier
from tllib.modules.entropy import entropy
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.reweight.iwan import ImportanceWeightModule, ImageClassifier
from tllib.alignment.dann import DomainAdversarialLoss
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """IWAN training entry point: adversarial adaptation with importance-weighted source samples."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        classifier = Classifier(backbone, num_classes, head=backbone.copy_head(),
                                pool_layer=pool_layer).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim,
                                     pool_layer=pool_layer).to(device)
    # define domain classifier D, D_0
    D = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024,
                            batch_norm=False).to(device)
    D_0 = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024,
                              batch_norm=False).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters() + D.get_parameters() + D_0.get_parameters(),
                    args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))

    # define loss function
    domain_adv_D = DomainAdversarialLoss(D).to(device)
    domain_adv_D_0 = DomainAdversarialLoss(D_0).to(device)
    # define importance weight module (weights derive from D's output)
    importance_weight_module = ImportanceWeightModule(D, train_target_dataset.partial_classes_idx)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, domain_adv_D, domain_adv_D_0,
              importance_weight_module, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, domain_adv_D: DomainAdversarialLoss,
          domain_adv_D_0: DomainAdversarialLoss, importance_weight_module,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one IWAN epoch: cross-entropy + adversarial losses for D (detached) and D_0 (weighted) + entropy."""
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    tgt_accs = AverageMeter('Tgt Acc', ':3.1f')
    domain_accs_D = AverageMeter('Domain Acc for D', ':3.1f')
    domain_accs_D_0 = AverageMeter('Domain Acc for D_0', ':3.1f')
    partial_classes_weights = AverageMeter('Partial Weight', ':3.2f')
    non_partial_classes_weights = AverageMeter('Non-Partial Weight', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs, tgt_accs, domain_accs_D, domain_accs_D_0,
         partial_classes_weights, non_partial_classes_weights],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()
    domain_adv_D.train()
    domain_adv_D_0.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)
        x_t, labels_t = next(train_target_iter)
        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        x = torch.cat((x_s, x_t), dim=0)
        y, f = model(x)
        y_s, y_t = y.chunk(2, dim=0)
        f_s, f_t = f.chunk(2, dim=0)

        # classification loss
        cls_loss = F.cross_entropy(y_s, labels_s)
        # domain adversarial loss for D (features detached: D does not backprop into the backbone)
        adv_loss_D = domain_adv_D(f_s.detach(), f_t.detach())
        # get importance weights
        w_s = importance_weight_module.get_importance_weight(f_s)
        # domain adversarial loss for D_0
        adv_loss_D_0 = domain_adv_D_0(f_s, f_t, w_s=w_s)
        # entropy loss
        y_t = F.softmax(y_t, dim=1)
        entropy_loss = entropy(y_t, reduction='mean')
        loss = cls_loss + 1.5 * args.trade_off * adv_loss_D + \
               args.trade_off * adv_loss_D_0 + args.gamma * entropy_loss

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        cls_acc = accuracy(y_s, labels_s)[0]
        tgt_acc = accuracy(y_t, labels_t)[0]

        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))
        tgt_accs.update(tgt_acc.item(), x_s.size(0))
        domain_accs_D.update(domain_adv_D.domain_discriminator_accuracy, x_s.size(0))
        domain_accs_D_0.update(domain_adv_D_0.domain_discriminator_accuracy, x_s.size(0))
        # debug: output class weight averaged on the partial classes and non-partial classes respectively
        partial_class_weight, non_partial_classes_weight = \
            importance_weight_module.get_partial_classes_weight(w_s, labels_s)
        partial_classes_weights.update(partial_class_weight.item(), x_s.size(0))
        non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0))

        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='IWAN for Partial Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of source (and target) dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain')
    parser.add_argument('-t', '--target', help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--bottleneck-dim', default=256, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='the trade-off hyper-parameter for entropy loss(default: 0.1)')
    parser.add_argument('--trade-off', default=3, type=float,
                        help='the trade-off hyper-parameter for transfer loss(default: 3))')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=10, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='iwan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/iwan.sh
================================================
#!/usr/bin/env bash
# Office31
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2W
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t W -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2W
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2D
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s A -t D -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_A2D
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s D -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_D2A
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office31 -d Office31 -s W -t A -a resnet50 --lr 0.0003 --seed 0 --log logs/iwan/Office31_W2A

# Office-Home
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Pr
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python iwan.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --seed 0 --log logs/iwan/OfficeHome_Rw2Pr

# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python iwan.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --lr 0.0003 --seed 0 --train-resizing cen.crop --per-class-eval --log logs/iwan/VisDA2017_S2R

# ImageNet-Caltech
CUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --seed 0 --log logs/iwan/I2C
CUDA_VISIBLE_DEVICES=0 python iwan.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \
  --seed 0 --bottleneck-dim 2048 --log logs/iwan/C2I


================================================
FILE: examples/domain_adaptation/partial_domain_adaptation/pada.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.domain_discriminator import DomainDiscriminator
from tllib.modules.classifier import Classifier
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
from tllib.reweight.pada import AutomaticUpdateClassWeightModule
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=False)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_transform, val_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    if args.data == 'ImageNetCaltech':
        classifier = Classifier(backbone, num_classes, pool_layer=pool_layer,
                                head=backbone.copy_head()).to(device)
    else:
        classifier = ImageClassifier(backbone, num_classes, args.bottleneck_dim,
                                     pool_layer=pool_layer).to(device)
    domain_discri = DomainDiscriminator(in_feature=classifier.features_dim, hidden_size=1024).to(device)
    class_weight_module = AutomaticUpdateClassWeightModule(args.class_weight_update_steps,
                                                           train_target_loader, classifier, num_classes,
                                                           device, args.temperature,
                                                           train_target_dataset.partial_classes_idx)
    # define optimizer and lr scheduler
    optimizer = 
SGD(classifier.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) # define loss function domain_adv = DomainAdversarialLoss(domain_discri).to(device) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = collect_feature(train_source_loader, feature_extractor, device) target_feature = collect_feature(train_target_loader, feature_extractor, device) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0. 
for epoch in range(args.epochs): # train for one epoch train(train_source_iter, train_target_iter, classifier, domain_adv, class_weight_module, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test_acc1 = {:3.1f}".format(acc1)) logger.close() def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model: ImageClassifier, domain_adv: DomainAdversarialLoss, class_weight_module: AutomaticUpdateClassWeightModule, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':5.2f') data_time = AverageMeter('Data', ':5.2f') losses = AverageMeter('Loss', ':6.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') domain_accs = AverageMeter('Domain Acc', ':3.1f') tgt_accs = AverageMeter('Tgt Acc', ':3.1f') partial_classes_weights = AverageMeter('Partial Weight', ':3.1f') non_partial_classes_weights = AverageMeter('Non-partial Weight', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_accs, domain_accs, tgt_accs, partial_classes_weights, non_partial_classes_weights], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() domain_adv.train() end = time.time() for i in range(args.iters_per_epoch): x_s, labels_s = next(train_source_iter) x_t, labels_t = next(train_target_iter) x_s = x_s.to(device) x_t = x_t.to(device) labels_s = labels_s.to(device) labels_t = labels_t.to(device) # measure data loading 
time data_time.update(time.time() - end) # compute output x = torch.cat((x_s, x_t), dim=0) y, f = model(x) y_s, y_t = y.chunk(2, dim=0) f_s, f_t = f.chunk(2, dim=0) cls_loss = F.cross_entropy(y_s, labels_s, class_weight_module.get_class_weight_for_cross_entropy_loss()) w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss(labels_s) transfer_loss = domain_adv(f_s, f_t, w_s, w_t) class_weight_module.step() partial_classes_weight, non_partial_classes_weight = class_weight_module.get_partial_classes_weight() domain_acc = domain_adv.domain_discriminator_accuracy loss = cls_loss + transfer_loss * args.trade_off cls_acc = accuracy(y_s, labels_s)[0] tgt_acc = accuracy(y_t, labels_t)[0] losses.update(loss.item(), x_s.size(0)) cls_accs.update(cls_acc.item(), x_s.size(0)) domain_accs.update(domain_acc.item(), x_s.size(0)) tgt_accs.update(tgt_acc.item(), x_s.size(0)) partial_classes_weights.update(partial_classes_weight.item(), x_s.size(0)) non_partial_classes_weights.update(non_partial_classes_weight.item(), x_s.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='PADA for Partial Domain Adaptation') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of source (and target) dataset') parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(), help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: Office31)') parser.add_argument('-s', '--source', help='source domain') parser.add_argument('-t', '--target', help='target domain') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') # model parameters parser.add_argument('-a', 
'--arch', metavar='ARCH', default='resnet18', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet18)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--bottleneck-dim', default=256, type=int, help='Dimension of bottleneck') parser.add_argument('-u', '--class-weight-update-steps', default=500, type=int, help='Number of steps to update class weight once') parser.add_argument('--temperature', default=0.1, type=float, help='temperature for softmax when calculating class weight') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # training parameters parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=0.002, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay',default=1e-3, type=float, metavar='W', help='weight decay (default: 1e-3)', dest='weight_decay') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument('--per-class-eval', action='store_true', help='whether output per-class accuracy during evaluation') parser.add_argument("--log", type=str, default='pada', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/pada.sh ================================================ #!/usr/bin/env bash # Office31 CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2W CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t W -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2W CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2D CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s A -t D -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_A2D CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s D -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_D2A CUDA_VISIBLE_DEVICES=0 python pada.py data/office31 -d Office31 -s W -t A -a resnet50 --epochs 20 --seed 0 --log logs/pada/Office31_W2A # Office-Home CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Cl CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Ar2Pr CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Ar -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log 
logs/pada/OfficeHome_Ar2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Pr
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Cl -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Cl2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Cl
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Pr -t Rw -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Pr2Rw
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Ar -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Ar
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Cl -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Cl
CUDA_VISIBLE_DEVICES=0 python pada.py data/office-home -d OfficeHome -s Rw -t Pr -a resnet50 --epochs 30 --temperature 0.01 --seed 0 --log logs/pada/OfficeHome_Rw2Pr

# VisDA-2017
CUDA_VISIBLE_DEVICES=0 python pada.py data/visda-2017 -d VisDA2017 -s Synthetic -t Real -a resnet50 \
  --epochs 20 --seed 0 -u 500 -i 500 --train-resizing cen.crop --trade-off 0.4 --per-class-eval --log logs/pada/VisDA2017_S2R

# ImageNet-Caltech
CUDA_VISIBLE_DEVICES=0 python pada.py data/ImageNetCaltech -d ImageNetCaltech -s I -t C -a resnet50 \
  --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --log logs/pada/I2C
CUDA_VISIBLE_DEVICES=0 python
pada.py data/ImageNetCaltech -d CaltechImageNet -s C -t I -a resnet50 \ --epochs 20 --seed 0 --lr 0.003 --temperature 0.01 -u 2000 -i 2000 --bottleneck-dim 2048 --log logs/pada/C2I ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/requirements.txt ================================================ timm ================================================ FILE: examples/domain_adaptation/partial_domain_adaptation/utils.py ================================================ import sys import time import timm import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as T sys.path.append('../../..') import tllib.vision.datasets.partial as datasets from tllib.vision.datasets.partial import default_partial as partial import tllib.vision.models as models from tllib.vision.transforms import ResizeImage from tllib.utils.metric import accuracy, ConfusionMatrix from tllib.utils.meter import AverageMeter, ProgressMeter def get_model_names(): return sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) + timm.list_models() def get_model(model_name): if model_name in models.__dict__: # load models from tllib.vision.models backbone = models.__dict__[model_name](pretrained=True) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=True) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') backbone.copy_head = backbone.get_classifier except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() backbone.copy_head = lambda x: x.head return backbone def get_dataset_names(): return sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) def get_dataset(dataset_name, root, source, target, train_source_transform, val_transform, train_target_transform=None): 
if train_target_transform is None: train_target_transform = train_source_transform # load datasets from tllib.vision.datasets dataset = datasets.__dict__[dataset_name] partial_dataset = partial(dataset) train_source_dataset = dataset(root=root, task=source, download=True, transform=train_source_transform) train_target_dataset = partial_dataset(root=root, task=target, download=True, transform=train_target_transform) val_dataset = partial_dataset(root=root, task=target, download=True, transform=val_transform) if dataset_name == 'DomainNet': test_dataset = partial_dataset(root=root, task=target, split='test', download=True, transform=val_transform) else: test_dataset = val_dataset class_names = train_source_dataset.classes num_classes = len(class_names) return train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, class_names def validate(val_loader, model, args, device) -> float: batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1], prefix='Test: ') # switch to evaluate mode model.eval() if args.per_class_eval: confmat = ConfusionMatrix(len(args.class_names)) else: confmat = None with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): images = images.to(device) target = target.to(device) # compute output output = model(images) loss = F.cross_entropy(output, target) # measure accuracy and record loss acc1, = accuracy(output, target, topk=(1,)) if confmat: confmat.update(target, output.argmax(1)) losses.update(loss.item(), images.size(0)) top1.update(acc1.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) if confmat: print(confmat.format(args.class_names)) return top1.avg def get_train_transform(resizing='default', 
random_horizontal_flip=True, random_color_jitter=False): """ resizing mode: - default: resize the image to 256 and take a random resized crop of size 224; - cen.crop: resize the image to 256 and take the center crop of size 224; - res: resize the image to 224; - res.|crop: resize the image to 256 and take a random crop of size 224; - res.sma|crop: resize the image keeping its aspect ratio such that the smaller side is 256, then take a random crop of size 224; – inc.crop: “inception crop” from (Szegedy et al., 2015); – cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224. """ if resizing == 'default': transform = T.Compose([ ResizeImage(256), T.RandomResizedCrop(224) ]) elif resizing == 'cen.crop': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224) ]) elif resizing == 'res.': transform = T.Resize(224) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize((256, 256)), T.RandomCrop(224) ]) elif resizing == "res.sma|crop": transform = T.Compose([ T.Resize(256), T.RandomCrop(224) ]) elif resizing == 'inc.crop': transform = T.RandomResizedCrop(224) elif resizing == 'cif.crop': transform = T.Compose([ T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224), ]) else: raise NotImplementedError(resizing) transforms = [transform] if random_horizontal_flip: transforms.append(T.RandomHorizontalFlip()) if random_color_jitter: transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)) transforms.extend([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return T.Compose(transforms) def get_val_transform(resizing='default'): """ resizing mode: - default: resize the image to 256 and take the center crop of size 224; – res.: resize the image to 224 – res.|crop: resize the image such that the smaller side is of size 256 and then take a central crop of size 224. 
""" if resizing == 'default': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224), ]) elif resizing == 'res.': transform = T.Resize((224, 224)) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize(256), T.CenterCrop(224), ]) else: raise NotImplementedError(resizing) return T.Compose([ transform, T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ================================================ FILE: examples/domain_adaptation/re_identification/README.md ================================================ # Unsupervised Domain Adaptation for Person Re-Identification ## Installation It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results. Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html) - [DukeMTMC](https://exposing.ai/duke_mtmc/) - [MSMT17](https://arxiv.org/pdf/1711.08565.pdf) ## Supported Methods Supported methods include: - [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - [Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (MMT, 2020 ICLR)](https://arxiv.org/abs/2001.01526) - [Similarity Preserving Generative Adversarial Network (SPGAN, 2018 CVPR)](https://arxiv.org/pdf/1811.10551.pdf) ## Usage The shell files give the script to reproduce the benchmark with specified hyper-parameters. For example, if you want to train MMT on Market1501 -> DukeMTMC task, use the following script ```shell script # Train MMT on Market1501 -> DukeMTMC task using ResNet 50. 
# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`, # or you are glad to download the datasets automatically from the Internet to this path # MMT involves two training steps: # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0 CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1 # step2: train mmt CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data -t DukeMTMC -a reid_resnet50 \ --pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \ --pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \ --finetune --seed 0 --log logs/mmt/Market2Duke ``` ### Experiment and Results In our experiments, we adopt modified resnet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf>). For a fair comparison, we use standard cross entropy loss and triplet loss in all methods. For methods that utilize clustering algorithms, we adopt kmeans or DBSCAN and report both results. **Notations** - ``Avg`` means the mAP (mean average precision) reported by `TLlib`. - ``Baseline_Cluster`` represents the strong baseline in [MMT](https://arxiv.org/pdf/2001.01526.pdf>). 
### Cross dataset mAP on ResNet-50 | Methods | Avg | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke | |--------------------------|------|-------------|-------------|-------------|-------------|-----------|-----------| | Baseline | 27.1 | 32.4 | 31.4 | 8.2 | 36.7 | 11.0 | 43.1 | | IBN | 30.0 | 35.2 | 36.5 | 11.3 | 38.7 | 14.1 | 44.3 | | SPGAN | 30.7 | 34.4 | 35.4 | 14.1 | 40.2 | 16.1 | 43.8 | | Baseline_Cluster(kmeans) | 45.1 | 52.8 | 59.5 | 19.0 | 62.6 | 20.3 | 56.2 | | Baseline_Cluster(dbscan) | 54.9 | 62.5 | 73.5 | 25.2 | 77.9 | 25.3 | 65.0 | | MMT(kmeans) | 55.4 | 63.7 | 72.5 | 26.2 | 75.8 | 28.0 | 66.1 | | MMT(dbscan) | 60.0 | 68.2 | 80.0 | 28.2 | 82.5 | 31.2 | 70.0 | ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{IBN-Net, author = {Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang}, title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net}, booktitle = {ECCV}, year = {2018} } @inproceedings{SPGAN, title={Image-image domain adaptation with preserved self-similarity and domain-dissimilarity for person re-identification}, author={Deng, Weijian and Zheng, Liang and Ye, Qixiang and Kang, Guoliang and Yang, Yi and Jiao, Jianbin}, booktitle={CVPR}, year={2018} } @inproceedings{MMT, title={Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification}, author={Yixiao Ge and Dapeng Chen and Hongsheng Li}, booktitle={ICLR}, year={2020}, } ``` ================================================ FILE: examples/domain_adaptation/re_identification/baseline.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import numpy as np import torch import torch.nn as nn from torch.nn import DataParallel import torch.backends.cudnn as cudnn from torch.optim import Adam from 
torch.utils.data import DataLoader import utils from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss from tllib.vision.models.reid.identifier import ReIdentifier import tllib.vision.datasets.reid as datasets from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset from tllib.utils.scheduler import WarmupMultiStepLR from tllib.utils.metric.reid import validate, visualize_ranked_results from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing, random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False, random_erasing=False) val_transform = utils.get_val_transform(args.height, args.width) print("train_transform: ", train_transform) print("val_transform: ", val_transform) working_dir = osp.dirname(osp.abspath(__file__)) source_root = osp.join(working_dir, args.source_root) target_root = osp.join(working_dir, args.target_root) # source dataset source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower())) sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances) train_source_loader = DataLoader( convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform), batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True) train_source_iter = ForeverDataIterator(train_source_loader) val_loader = DataLoader( convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)), root=source_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # target dataset target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower())) train_target_loader = DataLoader( convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=train_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True) train_target_iter = ForeverDataIterator(train_target_loader) test_loader = DataLoader( convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)), root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, 
num_workers=args.workers, shuffle=False, pin_memory=True) # create model num_classes = source_dataset.num_train_pids backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device) model = DataParallel(model) # define optimizer and lr scheduler optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay) lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1, warmup_steps=args.warmup_steps) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # plot t-SNE utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model, filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device) # visualize ranked results visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device, visualize_dir=logger.visualize_directory, width=args.width, height=args.height, rerank=args.rerank) return if args.phase == 'test': print("Test on source domain:") validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("Test on target domain:") validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) return # define loss function criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device) criterion_triplet = SoftTripletLoss(margin=args.margin).to(device) # start training best_val_mAP = 0. best_test_mAP = 0. 
    # main training loop: train one epoch on source, step the LR schedule,
    # and periodically validate / checkpoint
    for epoch in range(args.epochs):
        # print learning rate
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args)

        # update learning rate
        lr_scheduler.step()

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # evaluate on validation set
            print("Validation on source domain...")
            _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device,
                                  cmc_flag=True)

            # remember best mAP and save checkpoint
            # model selection uses SOURCE-domain mAP only; the target result below is
            # tracked separately as an "oracle" upper bound and never drives selection
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if val_mAP > best_val_mAP:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_val_mAP = max(val_mAP, best_val_mAP)

            # evaluate on test set
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            best_test_mAP = max(test_mAP, best_test_mAP)

    # evaluate on test set
    # NOTE(review): this load has no map_location — presumably fine because training
    # ran on the same device; confirm if checkpoints are reloaded on CPU-only hosts
    model.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    print("Test on target domain:")
    _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                           rerank=args.rerank)
    print("test mAP on target = {}".format(test_mAP))
    print("oracle mAP on target = {}".format(best_test_mAP))
    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model,
          criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, optimizer: Adam,
          epoch: int, args: argparse.Namespace):
    """Train the ReID model for one epoch.

    Supervision (cross entropy + soft triplet loss) comes from labeled source
    batches only. Target batches are forwarded too, but their outputs are
    discarded — presumably so target statistics reach the normalization layers
    while the model is in train mode (TODO confirm intent).
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_ce = AverageMeter('CeLoss', ':3.2f')
    losses_triplet = AverageMeter('TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, _, labels_s, _ = next(train_source_iter)
        x_t, _, _, _ = next(train_target_iter)

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        labels_s = labels_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_s, f_s = model(x_s)
        # target forward pass: y_t / f_t are intentionally unused (see docstring)
        y_t, f_t = model(x_t)

        # cross entropy loss
        loss_ce = criterion_ce(y_s, labels_s)
        # triplet loss (anchor and gallery features are both f_s)
        loss_triplet = criterion_triplet(f_s, f_s, labels_s)
        loss = loss_ce + loss_triplet * args.trade_off

        cls_acc = accuracy(y_s, labels_s)[0]
        losses_ce.update(loss_ce.item(), x_s.size(0))
        losses_triplet.update(loss_triplet.item(), x_s.size(0))
        losses.update(loss.item(), x_s.size(0))
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="Baseline for Domain Adaptative ReID")
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', type=str, help='source domain')
    parser.add_argument('-t', '--target', type=str, help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--rate', type=float, default=0.2) # training parameters parser.add_argument('--trade-off', type=float, default=1, help='trade-off hyper parameter between cross entropy loss and triplet loss') parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard') parser.add_argument('-j', '--workers', type=int, default=4) parser.add_argument('-b', '--batch-size', type=int, default=16) parser.add_argument('--height', type=int, default=256, help="input height") parser.add_argument('--width', type=int, default=128, help="input width") parser.add_argument('--num-instances', type=int, default=4, help="each minibatch consist of " "(batch_size // num_instances) identities, and " "each identity has num_instances instances, " "default: 4") parser.add_argument('--lr', type=float, default=0.00035, help="initial learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4) parser.add_argument('--epochs', type=int, default=80) parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps') parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70], help='milestones for the learning rate decay') parser.add_argument('--eval-step', type=int, default=40) parser.add_argument('--iters-per-epoch', type=int, default=400) parser.add_argument('--print-freq', type=int, default=40) parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.') parser.add_argument('--rerank', action='store_true', help="evaluation only") parser.add_argument("--log", type=str, default='baseline', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/re_identification/baseline.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke # Duke -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market # Market1501 -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT # MSMT -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market # Duke -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT # MSMT -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke ================================================ FILE: examples/domain_adaptation/re_identification/baseline_cluster.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import DataParallel import torch.backends.cudnn as cudnn from torch.optim import 
Adam from torch.utils.data import DataLoader from sklearn.cluster import KMeans, DBSCAN import utils import tllib.vision.datasets.reid as datasets from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset from tllib.vision.models.reid.identifier import ReIdentifier from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss from tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing, random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False, random_erasing=True) val_transform = utils.get_val_transform(args.height, args.width) print("train_transform: ", train_transform) print("val_transform: ", val_transform) working_dir = osp.dirname(osp.abspath(__file__)) source_root = osp.join(working_dir, args.source_root) target_root = osp.join(working_dir, args.target_root) # source dataset source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower())) val_loader = DataLoader( convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)), root=source_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # target dataset target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower())) cluster_loader = DataLoader( convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) test_loader = DataLoader( convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)), root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # create model num_classes = args.num_clusters backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device) model = DataParallel(model) # load pretrained weights pretrained_model = torch.load(args.pretrained_model_path) utils.copy_state_dict(model, pretrained_model) # resume from 
the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') utils.copy_state_dict(model, checkpoint['model']) # analysis the model if args.phase == 'analysis': # plot t-SNE utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model, filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device) # visualize ranked results visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device, visualize_dir=logger.visualize_directory, width=args.width, height=args.height, rerank=args.rerank) return if args.phase == 'test': print("Test on Source domain:") validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("Test on target domain:") validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) return # define loss function criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device) criterion_triplet = SoftTripletLoss(margin=args.margin).to(device) # optionally resume from a checkpoint if args.resume: checkpoint = torch.load(args.resume, map_location='cpu') utils.copy_state_dict(model, checkpoint['model']) args.start_epoch = checkpoint['epoch'] + 1 # start training best_test_mAP = 0. 
    # pseudo-label training loop: re-cluster target features every epoch,
    # rebuild the classifier head / optimizer, then train on pseudo labels
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        if args.clustering_algorithm == 'kmeans':
            # kmeans keeps the cluster count fixed at args.num_clusters,
            # so num_classes and criterion_ce from the outer scope are reused
            train_target_iter = run_kmeans(cluster_loader, model, target_dataset, train_transform, args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(cluster_loader, model, target_dataset, train_transform, args)
            # define cross entropy loss with current number of classes
            criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)

        # define optimizer
        # NOTE(review): the optimizer is recreated each epoch, so Adam moment
        # estimates reset with every re-clustering — presumably intentional; confirm
        optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr,
                         weight_decay=args.weight_decay)

        # train for one epoch
        train(train_target_iter, model, optimizer, criterion_ce, criterion_triplet, epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # remember best mAP and save checkpoint
            torch.save(
                {
                    'model': model.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch)
            )
            print("Test on target domain...")
            _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device,
                                   cmc_flag=True, rerank=args.rerank)
            if test_mAP > best_test_mAP:
                shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
            best_test_mAP = max(test_mAP, best_test_mAP)

    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()


def run_kmeans(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,
               args: argparse.Namespace):
    """Assign pseudo labels to the target training set with k-means.

    Extracts L2-normalized features for every target training image, clusters
    them into ``args.num_clusters`` groups, copies the normalized cluster
    centers into the classifier head's weight, and returns a
    :class:`ForeverDataIterator` over the pseudo-labeled training set.
    """
    # run kmeans clustering algorithm
    print('Clustering into {} classes'.format(args.num_clusters))
    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)
    feature = torch.stack(list(feature_dict.values())).cpu().numpy()
    km = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(feature)
    cluster_labels = km.labels_
    cluster_centers = km.cluster_centers_
    print('Clustering finished')

    # normalize cluster centers and convert to pytorch tensor
    cluster_centers = torch.from_numpy(cluster_centers).float().to(device)
    cluster_centers = F.normalize(cluster_centers, dim=1)

    # reinitialize classifier head
    # (head shape is unchanged: kmeans always yields args.num_clusters centers)
    model.module.head.weight.data.copy_(cluster_centers)

    # generate training set with pseudo labels
    # each entry is (file name, pseudo label, camera id)
    target_train_set = []
    for (fname, _, cid), label in zip(target_dataset.train, cluster_labels):
        target_train_set.append((fname, int(label), cid))

    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter


def run_dbscan(cluster_loader: DataLoader, model: DataParallel, target_dataset, train_transform,
               args: argparse.Namespace):
    """Assign pseudo labels to the target training set with DBSCAN.

    Clusters re-ranked pairwise distances of L2-normalized target features.
    Noise points (label ``-1``) are dropped from the pseudo-labeled set, so the
    number of classes varies between epochs; the (variable-size) classifier
    head is rebuilt from the normalized per-cluster mean features.
    """
    # run dbscan clustering algorithm
    feature_dict = extract_reid_feature(cluster_loader, model, device, normalize=True)
    feature = torch.stack(list(feature_dict.values())).cpu()
    rerank_dist = utils.compute_rerank_dist(feature).numpy()
    print('Clustering with dbscan algorithm')
    dbscan = DBSCAN(eps=0.6, min_samples=4, metric='precomputed', n_jobs=-1)
    cluster_labels = dbscan.fit_predict(rerank_dist)
    print('Clustering finished')

    # generate training set with pseudo labels and calculate cluster centers
    target_train_set = []
    cluster_centers = {}
    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):
        if label == -1:
            # skip samples DBSCAN marks as noise
            continue
        # NOTE(review): unlike run_kmeans, label is kept as a numpy integer
        # here (no int() cast) — verify downstream consumers accept that
        target_train_set.append((fname, label, cid))
        if label not in cluster_centers:
            cluster_centers[label] = []
        cluster_centers[label].append(feature[i])

    # mean feature per cluster, ordered by ascending cluster label
    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]
    cluster_centers = torch.stack(cluster_centers)
    # normalize cluster centers
    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)

    # reinitialize classifier head
features_dim = model.module.features_dim num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0) model.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device) model.module.head.weight.data.copy_(cluster_centers) sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances) train_target_loader = DataLoader( convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir, transform=train_transform), batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True) train_target_iter = ForeverDataIterator(train_target_loader) return train_target_iter, num_clusters def train(train_target_iter: ForeverDataIterator, model, optimizer, criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, epoch: int, args: argparse.Namespace): # train with pseudo labels batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses_ce = AverageMeter('CeLoss', ':3.2f') losses_triplet = AverageMeter('TripletLoss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x_t, _, labels_t, _ = next(train_target_iter) x_t = x_t.to(device) labels_t = labels_t.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_t, f_t = model(x_t) # cross entropy loss loss_ce = criterion_ce(y_t, labels_t) # triplet loss loss_triplet = criterion_triplet(f_t, f_t, labels_t) loss = loss_ce + loss_triplet * args.trade_off cls_acc = accuracy(y_t, labels_t)[0] losses_ce.update(loss_ce.item(), x_t.size(0)) losses_triplet.update(loss_triplet.item(), x_t.size(0)) losses.update(loss.item(), x_t.size(0)) cls_accs.update(cls_acc.item(), 
x_t.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description="Cluster Baseline for Domain Adaptative ReID") # dataset parameters parser.add_argument('source_root', help='root path of the source dataset') parser.add_argument('target_root', help='root path of the target dataset') parser.add_argument('-s', '--source', type=str, help='source domain') parser.add_argument('-t', '--target', type=str, help='target domain') parser.add_argument('--train-resizing', type=str, default='default') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: reid_resnet50)') parser.add_argument('--num-clusters', type=int, default=500) parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--rate', type=float, default=0.2) # training parameters parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'], help='clustering algorithm to run, currently supported method: ["kmeans", "dbscan"]') parser.add_argument('--resume', type=str, default=None, help="Where restore model parameters from.") parser.add_argument('--pretrained-model-path', type=str, help='path to pretrained (source-only) model') parser.add_argument('--trade-off', type=float, default=1, help='trade-off hyper parameter between cross entropy loss and triplet loss') parser.add_argument('--margin', 
type=float, default=0.0, help='margin for the triplet loss with batch hard') parser.add_argument('-j', '--workers', type=int, default=4) parser.add_argument('-b', '--batch-size', type=int, default=64) parser.add_argument('--height', type=int, default=256, help="input height") parser.add_argument('--width', type=int, default=128, help="input width") parser.add_argument('--num-instances', type=int, default=4, help="each minibatch consist of " "(batch_size // num_instances) identities, and " "each identity has num_instances instances, " "default: 4") parser.add_argument('--lr', type=float, default=0.00035, help="learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4) parser.add_argument('--epochs', type=int, default=40) parser.add_argument('--start-epoch', default=0, type=int, help='start epoch') parser.add_argument('--eval-step', type=int, default=1) parser.add_argument('--iters-per-epoch', type=int, default=400) parser.add_argument('--print-freq', type=int, default=40) parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.') parser.add_argument('--rerank', action='store_true', help="evaluation only") parser.add_argument("--log", type=str, default='baseline_cluster', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/re_identification/baseline_cluster.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2Duke # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --pretrained-model-path logs/baseline/Market2Duke/checkpoints/best.pth \ --finetune --seed 0 --log logs/baseline_cluster/Market2Duke # Duke -> Market1501 # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2Market # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \ --pretrained-model-path logs/baseline/Duke2Market/checkpoints/best.pth \ --finetune --seed 0 --log logs/baseline_cluster/Duke2Market # Market1501 -> MSMT # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMT # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \ --pretrained-model-path logs/baseline/Market2MSMT/checkpoints/best.pth \ --num-clusters 1000 --finetune --seed 0 --log logs/baseline_cluster/Market2MSMT # MSMT -> Market1501 # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 
-t Market1501 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Market # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \ --pretrained-model-path logs/baseline/MSMT2Market/checkpoints/best.pth \ --finetune --seed 0 --log logs/baseline_cluster/MSMT2Market # Duke -> MSMT # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMT # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ --pretrained-model-path logs/baseline/Duke2MSMT/checkpoints/best.pth \ --num-clusters 1000 --finetune --seed 0 --log logs/baseline_cluster/Duke2MSMT # MSMT -> Duke # step1: pretrain CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2Duke # step2: train with pseudo labels assigned by cluster algorithm CUDA_VISIBLE_DEVICES=0,1,2,3 python baseline_cluster.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ --pretrained-model-path logs/baseline/MSMT2Duke/checkpoints/best.pth \ --finetune --seed 0 --log logs/baseline_cluster/MSMT2Duke ================================================ FILE: examples/domain_adaptation/re_identification/ibn.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 
--finetune --seed 0 --log logs/ibn/Market2Duke # Duke -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2Market # Market1501 -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Market2MSMT # MSMT -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Market # Duke -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/Duke2MSMT # MSMT -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \ --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/ibn/MSMT2Duke ================================================ FILE: 
examples/domain_adaptation/re_identification/mmt.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import os.path as osp import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import DataParallel import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.utils.data import DataLoader from sklearn.cluster import KMeans, DBSCAN import utils import tllib.vision.datasets.reid as datasets from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset from tllib.vision.models.reid.identifier import ReIdentifier from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss, CrossEntropyLoss from tllib.self_training.mean_teacher import EMATeacher from tllib.vision.transforms import MultipleApply from tllib.utils.metric.reid import extract_reid_feature, validate, visualize_ranked_results from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing, random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False, random_erasing=True) val_transform = utils.get_val_transform(args.height, args.width) print("train_transform: ", train_transform) print("val_transform: ", val_transform) working_dir = osp.dirname(osp.abspath(__file__)) source_root = osp.join(working_dir, args.source_root) target_root = osp.join(working_dir, args.target_root) # source dataset source_dataset = datasets.__dict__[args.source](root=osp.join(source_root, args.source.lower())) val_loader = DataLoader( convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)), root=source_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # target dataset target_dataset = datasets.__dict__[args.target](root=osp.join(target_root, args.target.lower())) cluster_loader = DataLoader( convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) test_loader = DataLoader( convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)), root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # create model model_1, model_1_ema = create_model(args, args.pretrained_model_1_path) model_2, model_2_ema = create_model(args, args.pretrained_model_2_path) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') utils.copy_state_dict(model_1_ema, checkpoint) # analysis the model if args.phase == 'analysis': # plot t-SNE 
        # t-SNE of source vs. target features, computed with the EMA teacher of model 1
        utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model_1_ema,
                             filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device)
        # visualize ranked results
        visualize_ranked_results(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device,
                                 visualize_dir=logger.visualize_directory, width=args.width, height=args.height,
                                 rerank=args.rerank)
        return

    if args.phase == 'test':
        # evaluation only: report retrieval metrics on both domains, then exit
        print("Test on Source domain:")
        validate(val_loader, model_1_ema, source_dataset.query, source_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        print("Test on target domain:")
        validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device, cmc_flag=True,
                 rerank=args.rerank)
        return

    # define loss function
    # hard losses use pseudo labels; "soft" variants distill from the peer model
    # (MMT mutual mean-teaching); margin=None presumably switches SoftTripletLoss
    # into its soft-target mode — confirm against the loss implementation
    num_classes = args.num_clusters
    criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)
    criterion_ce_soft = CrossEntropyLoss().to(device)
    criterion_triplet = SoftTripletLoss(margin=0.0).to(device)
    criterion_triplet_soft = SoftTripletLoss(margin=None).to(device)

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        utils.copy_state_dict(model_1, checkpoint['model_1'])
        utils.copy_state_dict(model_1_ema, checkpoint['model_1_ema'])
        utils.copy_state_dict(model_2, checkpoint['model_2'])
        utils.copy_state_dict(model_2_ema, checkpoint['model_2_ema'])
        args.start_epoch = checkpoint['epoch'] + 1

    # start training
    best_test_mAP = 0.
    for epoch in range(args.start_epoch, args.epochs):
        # run clustering algorithm and generate pseudo labels
        if args.clustering_algorithm == 'kmeans':
            # kmeans keeps num_classes fixed at args.num_clusters
            train_target_iter = run_kmeans(cluster_loader, model_1, model_2, model_1_ema, model_2_ema,
                                           target_dataset, train_transform, args)
        elif args.clustering_algorithm == 'dbscan':
            train_target_iter, num_classes = run_dbscan(cluster_loader, model_1, model_2, model_1_ema, model_2_ema,
                                                        target_dataset, train_transform, args)
            # define cross entropy loss with current number of classes
            criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device)

        # define optimizer
        # one optimizer jointly updates both student models; recreated every epoch
        optimizer = Adam(model_1.module.get_parameters(base_lr=args.lr, rate=args.rate) + model_2.module.get_parameters(
            base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay)

        # train for one epoch
        train(train_target_iter, model_1, model_1_ema, model_2, model_2_ema, optimizer, criterion_ce,
              criterion_ce_soft, criterion_triplet, criterion_triplet_soft, epoch, args)

        if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1):
            # save checkpoint and remember best mAP
            torch.save(
                {
                    'model_1': model_1.state_dict(),
                    'model_1_ema': model_1_ema.state_dict(),
                    'model_2': model_2.state_dict(),
                    'model_2_ema': model_2_ema.state_dict(),
                    'epoch': epoch
                }, logger.get_checkpoint_path(epoch)
            )
            print("Test model_1 on target domain...")
            _, test_mAP_1 = validate(test_loader, model_1_ema, target_dataset.query, target_dataset.gallery, device,
                                     cmc_flag=True, rerank=args.rerank)
            print("Test model_2 on target domain...")
            _, test_mAP_2 = validate(test_loader, model_2_ema, target_dataset.query, target_dataset.gallery, device,
                                     cmc_flag=True, rerank=args.rerank)
            # keep the better EMA teacher as the 'best' checkpoint
            # NOTE(review): both comparisons are strict, so an exact tie
            # (test_mAP_1 == test_mAP_2) never updates 'best' — verify intended
            if test_mAP_1 > test_mAP_2 and test_mAP_1 > best_test_mAP:
                torch.save(model_1_ema.state_dict(), logger.get_checkpoint_path('best'))
                best_test_mAP = test_mAP_1
            if test_mAP_2 > test_mAP_1 and test_mAP_2 > best_test_mAP:
                torch.save(model_2_ema.state_dict(), logger.get_checkpoint_path('best'))
                best_test_mAP = test_mAP_2
    print("best mAP on target = {}".format(best_test_mAP))
    logger.close()


def create_model(args: argparse.Namespace, pretrained_model_path: str):
    """Build a DataParallel re-identification model from a source-pretrained checkpoint,
    plus its EMA (mean-teacher) copy. Returns ``(model, model_ema)``.
    """
    num_classes = args.num_clusters
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device)
    model = DataParallel(model)
    # load pretrained weights
    pretrained_model = torch.load(pretrained_model_path)
    utils.copy_state_dict(model, pretrained_model)
    # EMA model
    model_ema = EMATeacher(model, args.alpha)
    return model, model_ema


def run_kmeans(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,
               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):
    """Cluster target-domain features with k-means into ``args.num_clusters`` pseudo identities,
    reinitialize all four classifier heads with the (normalized) cluster centers, and return an
    infinite iterator over the pseudo-labeled target training set.
    """
    # run kmeans clustering algorithm
    print('Clustering into {} classes'.format(args.num_clusters))
    # collect feature with different ema teachers
    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)
    feature_1 = torch.stack(list(feature_dict_1.values())).cpu().numpy()
    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)
    feature_2 = torch.stack(list(feature_dict_2.values())).cpu().numpy()
    # average feature_1, feature_2 to create final feature
    feature = (feature_1 + feature_2) / 2
    km = KMeans(n_clusters=args.num_clusters, random_state=args.seed).fit(feature)
    cluster_labels = km.labels_
    cluster_centers = km.cluster_centers_
    print('Clustering finished')
    # normalize cluster centers and convert to pytorch tensor
    cluster_centers = torch.from_numpy(cluster_centers).float().to(device)
    cluster_centers = F.normalize(cluster_centers, dim=1)
    # reinitialize classifier head
    # NOTE(review): assumes the existing head weight has exactly args.num_clusters rows
    # (true when the model was created via create_model) — confirm before reusing elsewhere
    model_1.module.head.weight.data.copy_(cluster_centers)
    model_2.module.head.weight.data.copy_(cluster_centers)
    model_1_ema.module.head.weight.data.copy_(cluster_centers)
    model_2_ema.module.head.weight.data.copy_(cluster_centers)
    # generate training set with pseudo labels
    target_train_set = []
    for (fname, _, cid), label in zip(target_dataset.train, cluster_labels):
        target_train_set.append((fname, int(label), cid))
    # each batch contains args.num_instances images per pseudo identity (needed by the triplet loss)
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,
                                   transform=MultipleApply([train_transform, train_transform])),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter


def run_dbscan(cluster_loader: DataLoader, model_1: DataParallel, model_2: DataParallel, model_1_ema: EMATeacher,
               model_2_ema: EMATeacher, target_dataset, train_transform, args: argparse.Namespace):
    """Cluster target-domain features with DBSCAN over a re-ranked distance matrix.

    Samples labeled ``-1`` (noise) are discarded. All four classifier heads are *replaced*
    with new ``nn.Linear`` layers sized to the discovered number of clusters and initialized
    with the normalized cluster centers. Returns ``(train_target_iter, num_clusters)``.
    """
    # run dbscan clustering algorithm
    # collect feature with different ema teachers
    feature_dict_1 = extract_reid_feature(cluster_loader, model_1_ema, device, normalize=True)
    feature_1 = torch.stack(list(feature_dict_1.values())).cpu()
    feature_dict_2 = extract_reid_feature(cluster_loader, model_2_ema, device, normalize=True)
    feature_2 = torch.stack(list(feature_dict_2.values())).cpu()
    # average feature_1, feature_2 to create final feature
    feature = (feature_1 + feature_2) / 2
    feature = F.normalize(feature, dim=1)
    # pairwise re-ranked distances are fed to DBSCAN as a precomputed metric
    rerank_dist = utils.compute_rerank_dist(feature).numpy()
    print('Clustering with dbscan algorithm')
    dbscan = DBSCAN(eps=0.7, min_samples=4, metric='precomputed', n_jobs=-1)
    cluster_labels = dbscan.fit_predict(rerank_dist)
    print('Clustering finished')
    # generate training set with pseudo labels and calculate cluster centers
    target_train_set = []
    cluster_centers = {}
    for i, ((fname, _, cid), label) in enumerate(zip(target_dataset.train, cluster_labels)):
        if label == -1:
            # noise point: excluded from training this epoch
            continue
        target_train_set.append((fname, label, cid))
        if label not in cluster_centers:
            cluster_centers[label] = []
        cluster_centers[label].append(feature[i])
    # mean feature per cluster, ordered by label so row k of the head matches pseudo label k
    cluster_centers = [torch.stack(cluster_centers[idx]).mean(0) for idx in sorted(cluster_centers.keys())]
    cluster_centers = torch.stack(cluster_centers)
    # normalize cluster centers
    cluster_centers = F.normalize(cluster_centers, dim=1).float().to(device)
    # reinitialize classifier head
    features_dim = model_1.module.features_dim
    num_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
    model_1.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_2.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_1_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_2_ema.module.head = nn.Linear(features_dim, num_clusters, bias=False).to(device)
    model_1.module.head.weight.data.copy_(cluster_centers)
    model_2.module.head.weight.data.copy_(cluster_centers)
    model_1_ema.module.head.weight.data.copy_(cluster_centers)
    model_2_ema.module.head.weight.data.copy_(cluster_centers)
    sampler = RandomMultipleGallerySampler(target_train_set, args.num_instances)
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_train_set, root=target_dataset.images_dir,
                                   transform=MultipleApply([train_transform, train_transform])),
        batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True)
    train_target_iter = ForeverDataIterator(train_target_loader)
    return train_target_iter, num_clusters


def train(train_target_iter: ForeverDataIterator, model_1: DataParallel, model_1_ema: EMATeacher,
          model_2: DataParallel, model_2_ema: EMATeacher, optimizer: Adam,
          criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_ce_soft: CrossEntropyLoss,
          criterion_triplet: SoftTripletLoss, criterion_triplet_soft: SoftTripletLoss, epoch: int,
          args: argparse.Namespace):
    """One epoch of MMT mutual mean-teacher training: each student learns from hard pseudo labels
    plus soft targets produced by the *other* student's EMA teacher.
    """
    # train with pseudo labels
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    # statistics for model_1
    losses_ce_1 = AverageMeter('Model_1 CELoss', ':3.2f')
    losses_triplet_1 = AverageMeter('Model_1 TripletLoss', ':3.2f')
    cls_accs_1 = AverageMeter('Model_1 Cls Acc', ':3.1f')
    # statistics for model_2
    losses_ce_2 = AverageMeter('Model_2 CELoss', ':3.2f')
    losses_triplet_2 = AverageMeter('Model_2 TripletLoss', ':3.2f')
    cls_accs_2 = AverageMeter('Model_2 Cls Acc', ':3.1f')
    losses_ce_soft = AverageMeter('Soft CELoss', ':3.2f')
    losses_triplet_soft = AverageMeter('Soft TripletLoss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_ce_1, losses_triplet_1, cls_accs_1, losses_ce_2, losses_triplet_2,
         cls_accs_2, losses_ce_soft, losses_triplet_soft, losses],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model_1.train()
    model_2.train()
    model_1_ema.train()
    model_2_ema.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # below we ignore subscript `t` and use `x_1`, `x_2` to denote different augmented versions of origin samples
        # `x_t` from target domain
        (x_1, x_2), _, labels, _ = next(train_target_iter)
        x_1 = x_1.to(device)
        x_2 = x_2.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: each student sees a different augmentation of the same images
        y_1, f_1 = model_1(x_1)
        y_2, f_2 = model_2(x_2)
        # compute output by ema-teacher
        y_1_teacher, f_1_teacher = model_1_ema(x_1)
        y_2_teacher, f_2_teacher = model_2_ema(x_2)

        # cross entropy loss (hard pseudo labels)
        loss_ce_1 = criterion_ce(y_1, labels)
        loss_ce_2 = criterion_ce(y_2, labels)
        # triplet loss
        loss_triplet_1 = criterion_triplet(f_1, f_1, labels)
        loss_triplet_2 = criterion_triplet(f_2, f_2, labels)
        # soft cross entropy loss: each student distills from the OTHER student's teacher
        loss_ce_soft = criterion_ce_soft(y_1, y_2_teacher) + \
                       criterion_ce_soft(y_2, y_1_teacher)
        # soft triplet loss
        loss_triplet_soft = criterion_triplet_soft(f_1, f_2_teacher, labels) + \
                            criterion_triplet_soft(f_2, f_1_teacher, labels)
        # final objective: convex combination of hard and soft losses
        loss = (loss_ce_1 + loss_ce_2) * (1 - args.trade_off_ce_soft) + \
               (loss_triplet_1 + loss_triplet_2) * (1 - args.trade_off_triplet_soft) + \
               loss_ce_soft * args.trade_off_ce_soft + \
               loss_triplet_soft * args.trade_off_triplet_soft

        # update statistics
        batch_size = args.batch_size
        cls_acc_1 = accuracy(y_1, labels)[0]
        cls_acc_2 = accuracy(y_2, labels)[0]
        # model 1
        losses_ce_1.update(loss_ce_1.item(), batch_size)
        losses_triplet_1.update(loss_triplet_1.item(), batch_size)
        cls_accs_1.update(cls_acc_1.item(), batch_size)
        # model 2
        losses_ce_2.update(loss_ce_2.item(), batch_size)
        losses_triplet_2.update(loss_triplet_2.item(), batch_size)
        cls_accs_2.update(cls_acc_2.item(), batch_size)
        losses_ce_soft.update(loss_ce_soft.item(), batch_size)
        losses_triplet_soft.update(loss_triplet_soft.item(), batch_size)
        losses.update(loss.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update teacher: EMA momentum ramps up from 0 towards args.alpha over global steps
        global_step = epoch * args.iters_per_epoch + i + 1
        model_1_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))
        model_2_ema.set_alpha(min(args.alpha, 1 - 1 / global_step))
        model_1_ema.update()
        model_2_ema.update()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description="MMT for Domain Adaptative ReID")
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', type=str, help='source domain')
    parser.add_argument('-t', '--target', type=str, help='target domain')
    parser.add_argument('--train-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: reid_resnet50)')
    parser.add_argument('--num-clusters', type=int, default=500)
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--alpha', type=float, default=0.999, help='ema alpha')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--rate', type=float, default=0.2)
    # training parameters
    parser.add_argument('--clustering-algorithm', type=str, default='dbscan', choices=['kmeans', 'dbscan'],
                        help='clustering algorithm to run, currently supported method: ["kmeans", "dbscan"]')
    parser.add_argument('--resume', type=str, default=None, help="Where restore model parameters from.")
    parser.add_argument('--pretrained-model-1-path', type=str, help='path to pretrained (source-only) model_1')
    parser.add_argument('--pretrained-model-2-path', type=str, help='path to pretrained (source-only) model_2')
    parser.add_argument('--trade-off-ce-soft', type=float, default=0.5,
                        help='the trade off hyper parameter between cross entropy loss and soft cross entropy loss')
    parser.add_argument('--trade-off-triplet-soft', type=float, default=0.8,
                        help='the trade off hyper parameter between triplet loss and soft triplet loss')
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('-b', '--batch-size', type=int, default=64)
    parser.add_argument('--height', type=int, default=256, help="input height")
    parser.add_argument('--width', type=int, default=128, help="input width")
    parser.add_argument('--num-instances', type=int, default=4,
                        help="each minibatch consist of "
                             "(batch_size // num_instances) identities, and "
                             "each identity has num_instances instances, "
                             "default: 4")
    parser.add_argument('--lr', type=float, default=0.00035, help="learning rate")
    parser.add_argument('--weight-decay', type=float, default=5e-4)
    parser.add_argument('--epochs', type=int, default=40)
    parser.add_argument('--start-epoch', default=0, type=int, help='start epoch')
    parser.add_argument('--eval-step', type=int, default=1)
    parser.add_argument('--iters-per-epoch', type=int, default=400)
    parser.add_argument('--print-freq', type=int, default=40)
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.')
    parser.add_argument('--rerank', action='store_true', help="evaluation only")
    parser.add_argument("--log", type=str, default='mmt',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)



================================================
FILE: examples/domain_adaptation/re_identification/mmt.sh
================================================
#!/usr/bin/env bash
# MMT needs TWO differently-seeded source-only baselines per transfer task (step1),
# which then mutually teach each other during MMT training (step2).

# Market1501 -> Duke
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2DukeSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2DukeSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Market2DukeSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Market2DukeSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/Market2Duke

# Duke -> Market1501
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MarketSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MarketSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Duke2MarketSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Duke2MarketSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/Duke2Market

# Market1501 -> MSMT
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Market2MSMTSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Market2MSMTSeed1
# step2: train mmt (MSMT17 has more identities, hence --num-clusters 1000)
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Market2MSMTSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Market2MSMTSeed1/checkpoints/best.pth \
  --num-clusters 1000 --finetune --seed 0 --log logs/mmt/Market2MSMT

# MSMT -> Market1501
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2MarketSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2MarketSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/MSMT2MarketSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/MSMT2MarketSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/MSMT2Market

# Duke -> MSMT
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/Duke2MSMTSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/Duke2MSMTSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/Duke2MSMTSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/Duke2MSMTSeed1/checkpoints/best.pth \
  --num-clusters 1000 --finetune --seed 0 --log logs/mmt/Duke2MSMT

# MSMT -> Duke
# step1: pretrain
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/baseline/MSMT2DukeSeed0
CUDA_VISIBLE_DEVICES=0 python baseline.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 1 --log logs/baseline/MSMT2DukeSeed1
# step2: train mmt
CUDA_VISIBLE_DEVICES=0,1,2,3 python mmt.py data data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --pretrained-model-1-path logs/baseline/MSMT2DukeSeed0/checkpoints/best.pth \
  --pretrained-model-2-path logs/baseline/MSMT2DukeSeed1/checkpoints/best.pth \
  --finetune --seed 0 --log logs/mmt/MSMT2Duke


================================================
FILE: examples/domain_adaptation/re_identification/requirements.txt
================================================
timm
opencv-python


================================================
FILE: examples/domain_adaptation/re_identification/spgan.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools
import os.path as osp
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torchvision.transforms as T

sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
import tllib.translation.spgan as spgan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
import tllib.vision.datasets.reid as datasets
from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset
from tllib.vision.transforms import Denormalize
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    """Train SPGAN (CycleGAN + similarity-preserving siamese constraint) to translate the
    source ReID dataset into the target style, then optionally export the translated images.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = T.Compose([
        T.Resize(args.load_size, Image.BICUBIC),
        T.RandomCrop(args.input_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        # GAN convention: images scaled into [-1, 1]
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    working_dir = osp.dirname(osp.abspath(__file__))
    root = osp.join(working_dir, args.root)
    source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower()))
    train_source_loader = DataLoader(
        convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir,
                                   transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower()))
    train_target_loader = DataLoader(
        convert_to_pytorch_dataset(target_dataset.train, root=target_dataset.images_dir,
                                   transform=train_transform),
        batch_size=args.batch_size, num_workers=args.workers, shuffle=True, pin_memory=True, drop_last=True)
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (generators, discriminators and siamese network)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    siamese_net = spgan.SiameseNetwork(nsf=args.nsf).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr,
                       betas=(args.beta1, 0.999))
    optimizer_siamese = Adam(siamese_net.parameters(), lr=args.lr, betas=(args.beta1, 0.999))
    # constant lr for the first args.epochs epochs, then linear decay over args.epochs_decay epochs
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)
    lr_scheduler_siamese = LambdaLR(optimizer_siamese, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        siamese_net.load_state_dict(checkpoint['siamese_net'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        optimizer_siamese.load_state_dict(checkpoint['optimizer_siamese'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        lr_scheduler_siamese.load_state_dict(checkpoint['lr_scheduler_siamese'])
        args.start_epoch = checkpoint['epoch'] + 1

    if args.phase == 'test':
        # test phase only translates the source dataset with the (resumed) generator
        transform = T.Compose([
            T.Resize(args.test_input_size, Image.BICUBIC),
            cyclegan.transform.Translation(netG_S2T, device)
        ])
        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    criterion_contrastive = spgan.ContrastiveLoss(margin=args.margin)

    # define visualization function
    tensor_to_image = T.Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        T.ToPILImage()
    ])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, siamese_net,
              criterion_gan, criterion_cycle, criterion_identity, criterion_contrastive, optimizer_G,
              optimizer_D, optimizer_siamese, fake_S_pool, fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()
        lr_scheduler_siamese.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'siamese_net': siamese_net.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'optimizer_siamese': optimizer_siamese.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'lr_scheduler_siamese': lr_scheduler_siamese.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )

    if args.translated_root is not None:
        # after training, export the translated source dataset for downstream ReID training
        transform = T.Compose([
            T.Resize(args.test_input_size, Image.BICUBIC),
            cyclegan.transform.Translation(netG_S2T, device)
        ])
        source_dataset.translate(transform, osp.join(args.translated_root, args.source.lower()))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, netG_S2T, netG_T2S,
          netD_S, netD_T, siamese_net: spgan.SiameseNetwork,
          criterion_gan: cyclegan.LeastSquaresGenerativeAdversarialLoss, criterion_cycle: nn.L1Loss,
          criterion_identity: nn.L1Loss, criterion_contrastive: spgan.ContrastiveLoss, optimizer_G: Adam,
          optimizer_D: Adam, optimizer_siamese: Adam, fake_S_pool: ImagePool, fake_T_pool: ImagePool,
          epoch: int, visualize, args: argparse.Namespace):
    """One epoch of SPGAN training: generators every other iteration, the siamese network from
    epoch 1 onward, and both discriminators every iteration.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_contrastive_G = AverageMeter('contrastive_G', ':3.2f')
    losses_contrastive_siamese = AverageMeter('contrastive_siamese', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T, losses_cycle_S,
         losses_cycle_T, losses_identity_S, losses_identity_T, losses_contrastive_G,
         losses_contrastive_siamese],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        real_S, _, _, _ = next(train_source_iter)
        real_T, _, _, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # ===============================================
        # train the generators (every two iterations)
        # ===============================================
        if i % 2 == 0:
            # save memory: discriminators and siamese net do not need gradients for the G step
            set_requires_grad(netD_S, False)
            set_requires_grad(netD_T, False)
            set_requires_grad(siamese_net, False)
            # GAN loss D_T(G_S2T(S))
            loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
            # GAN loss D_S(G_T2S(B))
            loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
            # Cycle loss || G_T2S(G_S2T(S)) - S||
            loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
            # Cycle loss || G_S2T(G_T2S(T)) - T||
            loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
            # Identity loss
            # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
            identity_T = netG_S2T(real_T)
            loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
            # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
            identity_S = netG_T2S(real_S)
            loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
            # siamese network output
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T)
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S)
            # positive pair: an image and its translation should stay close (label 0)
            loss_contrastive_p_G = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                   criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair: translations should differ from real images of the same domain (label 1)
            loss_contrastive_n_G = criterion_contrastive(f_fake_T, f_real_T, 1) + \
                                   criterion_contrastive(f_fake_S, f_real_S, 1) + \
                                   criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_G = (loss_contrastive_p_G + 0.5 * loss_contrastive_n_G) / 4 * args.trade_off_contrastive
            # combined loss and calculate gradients
            # the contrastive term only joins once the siamese net has warmed up (epoch > 1)
            loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
            if epoch > 1:
                loss_G += loss_contrastive_G
            netG_S2T.zero_grad()
            netG_T2S.zero_grad()
            loss_G.backward()
            optimizer_G.step()

            # update corresponding statistics
            losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
            losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
            losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
            losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
            losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
            losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
            if epoch > 1:
                # NOTE(review): the raw tensor (not .item()) is stored here, unlike the meters above —
                # this keeps the autograd graph alive inside the meter; confirm whether .item() was intended
                losses_contrastive_G.update(loss_contrastive_G, real_S.size(0))

        # ===============================================
        # train the siamese network (when epoch > 0)
        # ===============================================
        if epoch > 0:
            set_requires_grad(siamese_net, True)
            # siamese network output (generator outputs detached: only the siamese net learns here)
            f_real_S = siamese_net(real_S)
            f_fake_T = siamese_net(fake_T.detach())
            f_real_T = siamese_net(real_T)
            f_fake_S = siamese_net(fake_S.detach())
            # positive pair
            loss_contrastive_p_siamese = criterion_contrastive(f_real_S, f_fake_T, 0) + \
                                         criterion_contrastive(f_real_T, f_fake_S, 0)
            # negative pair
            loss_contrastive_n_siamese = criterion_contrastive(f_real_S, f_real_T, 1)
            # contrastive loss
            loss_contrastive_siamese = (loss_contrastive_p_siamese + 2 * loss_contrastive_n_siamese) / 3
            # update siamese network
            siamese_net.zero_grad()
            loss_contrastive_siamese.backward()
            optimizer_siamese.step()
            # update corresponding statistics
            # NOTE(review): raw tensor stored instead of .item() — same concern as losses_contrastive_G above
            losses_contrastive_siamese.update(loss_contrastive_siamese, real_S.size(0))

        # ===============================================
        # train the discriminators
        # ===============================================
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        # Calculate GAN loss for discriminator D_S (fakes drawn from the history pool)
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        # update discriminators
        netD_S.zero_grad()
        netD_T.zero_grad()
        loss_D_S.backward()
        loss_D_T.backward()
        optimizer_D.step()

        # update corresponding statistics
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            # NOTE(review): identity_S / identity_T are only assigned inside the `i % 2 == 0` generator
            # branch; this visualization is safe only while print_freq is even — confirm if that invariant
            # is intended to hold
            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T, identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T", "identity_S",
                                     "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))


if __name__ == '__main__':
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='SPGAN for Domain Adaptative ReID')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-s', '--source', type=str, help='source domain')
    parser.add_argument('-t', '--target', type=str, help='target domain')
    parser.add_argument('--load-size', nargs='+', type=int, default=(286, 144), help='loading image size')
    parser.add_argument('--input-size', nargs='+', type=int, default=(256, 128),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
    parser.add_argument('--nsf', type=int, default=64, help='# of sianet filters int the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. '
                             'The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='resnet_9',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    # training parameters
    parser.add_argument("--resume", type=str, default=None, help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0, help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0, help='trade off for identity loss')
    parser.add_argument('--trade-off-contrastive', type=float, default=2.0, help='trade off for contrastive loss')
    parser.add_argument('--margin', type=float, default=2, help='margin for contrastive loss')
    parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N', help='mini-batch size (default: 8)')
    parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=15, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=15,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2000, type=int, help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    parser.add_argument('-p', '--print-freq', default=500, type=int, metavar='N',
                        help='print frequency (default: 500)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='spgan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(256, 128),
                        help='the input image size during testing')
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/domain_adaptation/re_identification/spgan.sh
================================================
# Two-stage recipe: (1) train SPGAN and export a style-translated copy of the source dataset,
# (2) train a plain ReID baseline on the translated source images.

# Market1501 -> Duke
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t DukeMTMC \
  --log logs/spgan/Market2Duke --translated-root data/spganM2D --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2D data -s Market1501 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2Duke

# Duke -> Market1501
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t Market1501 \
  --log logs/spgan/Duke2Market --translated-root data/spganD2M --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2M data -s DukeMTMC -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2Market

# Market1501 -> MSMT17
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s Market1501 -t MSMT17 \
  --log logs/spgan/Market2MSMT --translated-root data/spganM2S --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganM2S data -s Market1501 -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Market2MSMT

# MSMT -> Market1501
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t Market1501 \
  --log logs/spgan/MSMT2Market --translated-root data/spganS2M --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2M data -s MSMT17 -t Market1501 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Market

# Duke -> MSMT
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s DukeMTMC -t MSMT17 \
  --log logs/spgan/Duke2MSMT --translated-root data/spganD2S --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganD2S data -s DukeMTMC -t MSMT17 -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/Duke2MSMT

# MSMT -> Duke
# step1: train SPGAN
CUDA_VISIBLE_DEVICES=0 python spgan.py data -s MSMT17 -t DukeMTMC \
  --log logs/spgan/MSMT2Duke --translated-root data/spganS2D --seed 0
# step2: train baseline on translated source dataset
CUDA_VISIBLE_DEVICES=0 python baseline.py data/spganS2D data -s MSMT17 -t DukeMTMC -a reid_resnet50 \
  --iters-per-epoch 800 --print-freq 80 --finetune --seed 0 --log logs/spgan/MSMT2Duke


================================================
FILE: examples/domain_adaptation/re_identification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import sys
import timm
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Parameter
import torchvision.transforms as T

sys.path.append('../../..')
from tllib.utils.metric.reid import extract_reid_feature
from tllib.utils.analysis import tsne
from tllib.vision.transforms import RandomErasing
import tllib.vision.models.reid as models
import tllib.normalization.ibn as ibn_models


def copy_state_dict(model, state_dict, strip=None):
    """Copy state dict into
the passed in ReID model. As we are using classification loss, which means we need to output different number of classes(identities) for different datasets, we will not copy the parameters of last `fc` layer. """ tgt_state = model.state_dict() copied_names = set() for name, param in state_dict.items(): if strip is not None and name.startswith(strip): name = name[len(strip):] if name not in tgt_state: continue if isinstance(param, Parameter): param = param.data if param.size() != tgt_state[name].size(): print('mismatch:', name, param.size(), tgt_state[name].size()) continue tgt_state[name].copy_(param) copied_names.add(name) missing = set(tgt_state.keys()) - copied_names if len(missing) > 0: print("missing keys in state_dict:", missing) return model def get_model_names(): return sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) + \ sorted(name for name in ibn_models.__dict__ if name.islower() and not name.startswith("__") and callable(ibn_models.__dict__[name])) + \ timm.list_models() def get_model(model_name): if model_name in models.__dict__: # load models from tllib.vision.models backbone = models.__dict__[model_name](pretrained=True) elif model_name in ibn_models.__dict__: # load models (with ibn) from tllib.normalization.ibn backbone = ibn_models.__dict__[model_name](pretrained=True) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=True) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() return backbone def get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False, random_erasing=False): """ resizing mode: - default: resize the image to (height, width), zero-pad it by 10 on each size, the take a random crop of (height, width) - 
res: resize the image to (height, width)
    """
    if resizing == 'default':
        transform = T.Compose([
            T.Resize((height, width), interpolation=3),  # 3 = PIL bicubic
            T.Pad(10),
            T.RandomCrop((height, width))
        ])
    elif resizing == 'res':
        transform = T.Resize((height, width), interpolation=3)
    else:
        raise NotImplementedError(resizing)
    transforms = [transform]
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))
    if random_gray_scale:
        transforms.append(T.RandomGrayscale())
    transforms.extend([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])
    if random_erasing:
        # appended after ToTensor/Normalize: this RandomErasing operates on the
        # normalized tensor, not on the PIL image
        transforms.append(RandomErasing(probability=0.5, mean=[0.485, 0.456, 0.406]))
    return T.Compose(transforms)


def get_val_transform(height, width):
    """Deterministic resize + normalization transform used at evaluation time."""
    return T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])


def visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):
    """Visualize features from different domains using t-SNE.

    As we can have very large number of samples in each domain, only
    `n_data_points_per_domain` number of samples are randomly selected in each domain.
"""
    source_feature_dict = extract_reid_feature(source_loader, model, device, normalize=True)
    source_feature = torch.stack(list(source_feature_dict.values())).cpu()
    # shuffle, then keep at most `n_data_points_per_domain` source features
    source_feature = source_feature[torch.randperm(len(source_feature))]
    source_feature = source_feature[:n_data_points_per_domain]
    target_feature_dict = extract_reid_feature(target_loader, model, device, normalize=True)
    target_feature = torch.stack(list(target_feature_dict.values())).cpu()
    # same random subsampling for the target domain
    target_feature = target_feature[torch.randperm(len(target_feature))]
    target_feature = target_feature[:n_data_points_per_domain]
    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue', target_color='darkorange')
    print('T-SNE process is done, figure is saved to {}'.format(filename))


def k_reciprocal_neigh(initial_rank, i, k1):
    """Compute k-reciprocal neighbors of i-th sample.

    Two samples f_i, f_j are k reciprocal-neighbors if and only if each one of them is
    among the k-nearest samples of another sample.

    Args:
        initial_rank: (n, n) tensor where row r lists all sample indices sorted by
            ascending distance to sample r
        i (int): index of the query sample
        k1 (int): neighborhood size
    """
    # the k1 nearest samples of i (the +1 accounts for i itself at rank 0)
    forward_k_neigh_index = initial_rank[i, :k1 + 1]
    # for each of those, look at their own k1-nearest lists ...
    backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1]
    # ... and keep only the neighbors whose own list contains i (reciprocity)
    fi = torch.nonzero(backward_k_neigh_index == i)[:, 0]
    return forward_k_neigh_index[fi]


def compute_rerank_dist(target_features, k1=30, k2=6):
    """Compute distance according to `Re-ranking Person Re-identification with k-reciprocal Encoding (CVPR 2017) `_.
"""
    n = target_features.size(0)
    # NOTE(review): 2*||x_i||^2 - 2*x_i.x_j equals the pairwise squared Euclidean
    # distance only when features are L2-normalized — presumably guaranteed by the
    # caller (features are extracted with normalize=True); verify.
    original_dist = torch.pow(target_features, 2).sum(dim=1, keepdim=True) * 2
    original_dist = original_dist.expand(n, n) - 2 * torch.mm(target_features, target_features.t())
    # scale each column by its maximum, then transpose so that rows index queries
    original_dist /= original_dist.max(0)[0]
    original_dist = original_dist.t()
    initial_rank = torch.argsort(original_dist, dim=-1)
    all_num = gallery_num = original_dist.size(0)
    del target_features
    # k-reciprocal neighbor sets at sizes k1 and k1/2 for every sample
    nn_k1 = []
    nn_k1_half = []
    for i in range(all_num):
        nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
        nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))
    # V[i] is a sparse row of Gaussian-weighted affinities over i's expanded neighborhood
    V = torch.zeros(all_num, all_num)
    for i in range(all_num):
        k_reciprocal_index = nn_k1[i]
        k_reciprocal_expansion_index = k_reciprocal_index
        # expand the neighborhood with candidates whose half-size reciprocal set
        # overlaps i's set by more than 2/3
        for candidate in k_reciprocal_index:
            candidate_k_reciprocal_index = nn_k1_half[candidate]
            if (len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2 / 3 * len(
                    candidate_k_reciprocal_index)):
                k_reciprocal_expansion_index = torch.cat((k_reciprocal_expansion_index, candidate_k_reciprocal_index))
        k_reciprocal_expansion_index = torch.unique(k_reciprocal_expansion_index)
        weight = torch.exp(-original_dist[i, k_reciprocal_expansion_index])
        V[i, k_reciprocal_expansion_index] = weight / torch.sum(weight)
    if k2 != 1:
        # query expansion: replace each row of V by the mean over its k2 nearest rows
        k2_rank = initial_rank[:, :k2].clone().view(-1)
        V_qe = V[k2_rank]
        V_qe = V_qe.view(initial_rank.size(0), k2, -1).sum(1)
        V_qe /= k2
        V = V_qe
        del V_qe
    del initial_rank
    # invIndex[j]: indices of rows with a non-zero affinity to gallery sample j
    invIndex = []
    for i in range(gallery_num):
        invIndex.append(torch.nonzero(V[:, i])[:, 0])
    jaccard_dist = torch.zeros_like(original_dist)
    for i in range(all_num):
        # accumulate, per gallery sample, the element-wise min of the two affinity
        # rows (the "intersection" term of the Jaccard distance)
        temp_min = torch.zeros(1, gallery_num)
        indNonZero = torch.nonzero(V[i, :])[:, 0]
        indImages = [invIndex[ind] for ind in indNonZero]
        for j in range(len(indNonZero)):
            temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + \
                                        torch.min(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]])
        jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
    del invIndex
    del V
    # clamp tiny negative values introduced by floating point error
    pos_bool = (jaccard_dist < 0)
    jaccard_dist[pos_bool] = 0.0
    return
jaccard_dist


================================================
FILE: examples/domain_adaptation/semantic_segmentation/README.md
================================================
# Unsupervised Domain Adaptation for Semantic Segmentation

It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results.

## Dataset

You need to prepare the following datasets manually if you want to use them:

- [Cityscapes](https://www.cityscapes-dataset.com/)
- [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/)
- [Synthia](https://synthia-dataset.net/)

#### Cityscapes, Foggy Cityscapes

- Download the Cityscapes and Foggy Cityscapes datasets from the [link](https://www.cityscapes-dataset.com/downloads/). In particular, we use *leftImg8bit_trainvaltest.zip* for Cityscapes and *leftImg8bit_trainvaltest_foggy.zip* for Foggy Cityscapes.
- Unzip them under a directory like

```
data/Cityscapes
├── gtFine
├── leftImg8bit
│   ├── train
│   ├── val
│   └── test
├── leftImg8bit_foggy
│   ├── train
│   ├── val
│   └── test
└── ...
```

#### GTA-5

You need to download GTA5 manually from [GTA5](https://download.visinf.tu-darmstadt.de/data/from_games/). Ensure that the following directories exist before you use this dataset.

```
data/GTA5
├── images
├── labels
└── ...
```

#### Synthia

You need to download Synthia manually from [Synthia](https://synthia-dataset.net/). Ensure that the following directories exist before you use this dataset.

```
data/synthia
├── RGB
├── synthia_mapped_to_cityscapes
└── ...
``` ## Supported Methods Supported methods include: - [Cycle-Consistent Adversarial Networks (CycleGAN)](https://arxiv.org/pdf/1703.10593.pdf) - [CyCADA: Cycle-Consistent Adversarial Domain Adaptation](https://arxiv.org/abs/1711.03213) - [Adversarial Entropy Minimization (ADVENT)](https://arxiv.org/abs/1811.12833) - [Fourier Domain Adaptation (FDA)](https://arxiv.org/abs/2004.05498) ## Experiment and Results **Notations** - ``Origin`` means the accuracy reported by the original paper. - ``mIoU`` is the accuracy reported by `TLlib`. - ``ERM`` refers to the model trained with data from the source domain. - ``Oracle`` refers to the model trained with data from the target domain. ### GTA5->Cityscapes mIoU on deeplabv2 (ResNet-101) | GTA5 | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrian | sky | person | rider | car | truck | bus | train | motorbike | bicycle | |-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------| | ERM | 27.1 | 37.3 | 66.5 | 17.4 | 73.3 | 13.4 | 21.5 | 22.8 | 30.1 | 17.1 | 82.2 | 7.1 | 73.6 | 57.4 | 28.4 | 78.6 | 36.1 | 13.4 | 1.5 | 31.9 | 36.2 | | AdvEnt | 43.8 | 43.8 | 89.3 | 33.9 | 80.3 | 24.0 | 25.2 | 27.8 | 36.7 | 18.2 | 84.3 | 33.9 | 81.3 | 59.8 | 28.4 | 84.3 | 34.1 | 44.4 | 0.1 | 33.2 | 12.9 | | FDA | 44.6 | 45.6 | 85.5 | 31.7 | 81.8 | 27.1 | 24.9 | 28.9 | 38.1 | 23.2 | 83.7 | 40.3 | 80.6 | 60.5 | 30.3 | 79.1 | 32.8 | 45.1 | 5.0 | 32.4 | 35.2 | | Cycada | 42.7 | 47.4 | 87.3 | 35.7 | 83.7 | 31.3 | 24.0 | 32.2 | 35.8 | 30.3 | 82.7 | 32.0 | 85.7 | 60.8 | 31.5 | 85.6 | 39.8 | 43.3 | 5.4 | 29.5 | 44.6 | | CycleGAN | | 47.0 | 88.4 | 41.9 | 83.6 | 34.4 | 23.9 | 32.9 | 35.5 | 26.0 | 83.1 | 36.8 | 82.3 | 59.9 | 27.0 | 83.4 | 31.6 | 42.3 | 11.0 | 28.2 | 40.5 | | Oracle | 65.1 | 70.5 | 97.4 | 79.7 | 90.1 | 53.0 | 50.0 | 48.0 | 
55.5 | 67.2 | 90.2 | 60.0 | 93.0 | 72.7 | 55.2 | 92.7 | 76.5 | 78.5 | 56.0 | 54.6 | 68.8 | ### Synthia->Cityscapes mIoU on deeplabv2 (ResNet-101) | Synthia | Origin | mIoU | road | sidewalk | building | traffic light | traffic sign | vegetation | sky | person | rider | car | bus | motorbike | bicycle | |-------------|--------|------|------|----------|----------|---------------|--------------|------------|------|--------|-------|------|------|-----------|---------| | ERM | 22.1 | 41.5 | 59.6 | 21.1 | 77.4 | 7.7 | 17.6 | 78.0 | 84.5 | 53.2 | 16.9 | 65.9 | 24.9 | 8.5 | 24.8 | | AdvEnt | 47.6 | 47.9 | 88.3 | 44.9 | 80.5 | 4.5 | 9.1 | 81.3 | 86.2 | 52.9 | 21.0 | 82.0 | 30.3 | 11.9 | 30.2 | | FDA | - | 43.9 | 62.5 | 23.7 | 78.5 | 9.4 | 15.7 | 78.3 | 81.1 | 52.3 | 18.7 | 79.8 | 32.5 | 8.7 | 29.6 | | Oracle | 71.7 | 76.6 | 97.4 | 79.7 | 90.1 | 55.5 | 67.2 | 90.2 | 93.0 | 72.7 | 55.2 | 92.7 | 78.5 | 54.6 | 68.8 | ### Cityscapes->Foggy Cityscapes mIoU on deeplabv2 (ResNet-101) | Foggy | Origin | mIoU | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrian | sky | person | rider | car | truck | bus | train | motorbike | bicycle | |-------------|--------|------|------|----------|----------|------|-------|------|---------------|--------------|------------|---------|------|--------|-------|------|-------|------|-------|-----------|---------| | ERM | | 51.2 | 95.3 | 70.2 | 64.1 | 31.9 | 35.2 | 30.7 | 33.3 | 51.1 | 42.3 | 44.0 | 32.1 | 64.4 | 47.0 | 86.0 | 64.4 | 56.4 | 21.1 | 43.1 | 60.8 | | AdvEnt | | 61.8 | 96.8 | 75.1 | 76.4 | 46.2 | 42.6 | 39.3 | 43.6 | 58.9 | 74.3 | 50.1 | 75.9 | 67.3 | 51.0 | 89.4 | 70.5 | 64.7 | 39.9 | 47.9 | 65.0 | | FDA | | 61.9 | 96.9 | 77.2 | 75.3 | 46.5 | 42.0 | 39.8 | 47.1 | 61.0 | 72.7 | 54.6 | 63.8 | 68.4 | 50.1 | 90.1 | 72.8 | 68.0 | 35.5 | 50.8 | 64.2 | | Cycada | | 63.3 | 96.8 | 75.5 | 79.1 | 38.0 | 40.3 | 42.1 | 48.2 | 61.2 | 76.9 | 52.1 | 77.6 | 68.6 | 51.7 | 90.4 | 71.7 | 70.4 | 43.3 | 52.6 | 
65.7 | | CycleGAN | | 66.0 | 97.1 | 77.6 | 84.3 | 42.7 | 46.3 | 42.8 | 47.5 | 61.0 | 84.0 | 55.2 | 83.4 | 69.4 | 51.8 | 90.7 | 73.7 | 76.2 | 54.2 | 50.7 | 65.6 | | Oracle | | 66.9 | 97.4 | 78.6 | 88.1 | 50.7 | 50.5 | 46.2 | 51.3 | 64.4 | 88.1 | 55.3 | 87.4 | 70.9 | 52.7 | 91.6 | 72.4 | 73.2 | 31.8 | 52.2 | 67.4 | ## Visualization If you want to visualize the segmentation results during training, you should set ``--debug``. ``` CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes --log logs/src_only/gtav2cityscapes --debug ``` Then you can find images, predictions and labels in directory ``logs/src_only/gtav2cityscapes/visualize/``. Translation model such as CycleGAN will save images by default. Here is the source-style images and its translated version. ## TODO Support methods: AdaptSeg ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{CycleGAN, title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, booktitle={ICCV}, year={2017} } @inproceedings{cycada, title={Cycada: Cycle-consistent adversarial domain adaptation}, author={Hoffman, Judy and Tzeng, Eric and Park, Taesung and Zhu, Jun-Yan and Isola, Phillip and Saenko, Kate and Efros, Alexei and Darrell, Trevor}, booktitle={ICML}, year={2018}, } @inproceedings{Advent, author = {Vu, Tuan-Hung and Jain, Himalaya and Bucher, Maxime and Cord, Matthieu and Perez, Patrick}, title = {ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation}, booktitle = {CVPR}, year = {2019} } @inproceedings{FDA, author = {Yanchao Yang and Stefano Soatto}, title = {{FDA:} Fourier Domain Adaptation for Semantic Segmentation}, booktitle = {CVPR}, year = {2020} } ``` ================================================ FILE: examples/domain_adaptation/semantic_segmentation/advent.py 
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
from PIL import Image
import numpy as np
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

sys.path.append('../../..')
from tllib.alignment.advent import Discriminator, DomainAdversarialEntropyLoss
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
import tllib.vision.transforms.segmentation as T
from tllib.vision.transforms import DeNormalizeAndTranspose
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter, Meter
from tllib.utils.logger import CompleteLogger

# run on GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Train (or, with --phase test, only evaluate) an ADVENT segmentation model.

    Builds source/target/validation data pipelines, the segmentation model and the
    domain discriminator, then alternates adversarial training via :func:`train`
    and evaluation via :func:`validate`, checkpointing the best mean IoU.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # fix random seeds for (mostly) reproducible runs; cudnn.deterministic trades speed
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(
        root=args.target_root,
        transforms=T.Compose([
            # fixed 2:1 aspect ratio for the target crops
            T.RandomResizedCrop(size=args.train_size, ratio=(2., 2.), scale=(0.5, 1.)),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    num_classes = train_source_dataset.num_classes
    model = models.__dict__[args.arch](num_classes=num_classes).to(device)
    discriminator = Discriminator(num_classes=num_classes).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer_d = Adam(discriminator.parameters(), lr=args.lr_d, betas=(0.9, 0.99))
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by the lambda's result,
    # and this lambda *also* includes args.lr — presumably model.get_parameters()
    # defines per-group lrs this compensates for (cf. lr_scheduler_d below, which
    # omits the factor); verify.
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))
    lr_scheduler_d = LambdaLR(optimizer_d, lambda x: (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))

    # optionally resume from a checkpoint
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        discriminator.load_state_dict(checkpoint['discriminator'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        lr_scheduler_d.load_state_dict(checkpoint['lr_scheduler_d'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    dann = DomainAdversarialEntropyLoss(discriminator)
    # sizes are configured as (width, height); Upsample expects (height, width),
    # hence the [::-1] reversal — presumably; confirm against the dataset transforms
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """Save the input image, its decoded prediction and decoded label to the log dir.

        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr(), lr_scheduler_d.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train, criterion, dann, optimizer,
              lr_scheduler, optimizer_d, lr_scheduler_d, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes
        indexes = [train_source_dataset.classes.index(name) for name in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer': optimizer.state_dict(),
                'optimizer_d': optimizer_d.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'lr_scheduler_d': lr_scheduler_d.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model, interp, criterion,
          dann, optimizer: SGD, lr_scheduler: LambdaLR, optimizer_d: Adam, lr_scheduler_d: LambdaLR,
          epoch: int, visualize, args: argparse.Namespace):
    """Train the segmentation model and the domain discriminator for one epoch.

    Each iteration alternates between (1) updating the segmentation network on the
    source cross-entropy loss plus the adversarial entropy loss on target predictions
    (discriminator frozen via dann.eval()) and (2) updating the discriminator on the
    detached predictions from both domains.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_transfer = AverageMeter('Loss (transfer)', ':3.2f')
    losses_discriminator = AverageMeter('Loss (discriminator)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_transfer, losses_discriminator,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        optimizer.zero_grad()
        optimizer_d.zero_grad()

        # Step 1: Train the segmentation network, freeze the discriminator
        dann.eval()
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # adversarial training to fool the discriminator: target predictions are
        # labeled 'source' so their gradient pushes the generator toward the source domain
        y_t = model(x_t)
        pred_t = interp(y_t)
        loss_transfer = dann(pred_t, 'source')
        (loss_transfer * args.trade_off).backward()

        # Step 2: Train the discriminator (detach so no gradient reaches the model)
        dann.train()
        loss_discriminator = 0.5 * (dann(pred_s.detach(), 'source') + dann(pred_t.detach(), 'target'))
        loss_discriminator.backward()

        # compute gradient and do SGD step
        optimizer.step()
        optimizer_d.step()
        lr_scheduler.step()
        lr_scheduler_d.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_transfer.update(loss_transfer.item(), x_s.size(0))
        losses_discriminator.update(loss_discriminator.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))


def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Evaluate `model` on `val_loader`; return the accumulated ConfusionMatrix."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc, iou],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        end = time.time()
        for i, (x, label) in enumerate(val_loader):
            x = x.to(device)
            label = label.long().to(device)

            # compute output
            output = interp(model(x))
            loss = criterion(output, label)

            # measure accuracy and record loss
            losses.update(loss.item(), x.size(0))
            confmat.update(label.flatten(), output.argmax(1).flatten())
            acc_global, accs, iu = confmat.compute()
            acc.update(accs.mean().item())
            iou.update(iu.mean().item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                if visualize is not None:
                    visualize(x[0], output[0], label[0], "val_{}".format(i))

    return confmat


if __name__ == '__main__':
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='ADVENT for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', 
type=float, default=(1.5, 8 / 3.), help='the resize ratio for the random resize crop') parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512), help='the input and output image size during training') parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512), help='the input image size during test') parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024), help='the output image size during test') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: deeplabv2_resnet101)') parser.add_argument("--resume", type=str, default=None, help="Where restore model parameters from.") parser.add_argument('--trade-off', type=float, default=0.001, help='trade-off parameter for the advent loss') # training parameters parser.add_argument('-b', '--batch-size', default=2, type=int, metavar='N', help='mini-batch size (default: 2)') parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument("--momentum", type=float, default=0.9, help="Momentum component of the optimiser.") parser.add_argument("--weight-decay", type=float, default=0.0005, help="Regularisation parameter for L2-loss.") parser.add_argument("--lr-power", type=float, default=0.9, help="Decay parameter to compute the learning rate (only for deeplab).") parser.add_argument("--lr-d", default=1e-4, type=float, metavar='LR', help='initial learning rate for discriminator') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('-i', 
'--iters-per-epoch', default=2500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--ignore-label", type=int, default=255, help="The index of the label to ignore during the training.") parser.add_argument("--log", type=str, default='advent', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") parser.add_argument('--debug', action="store_true", help='In the debug mode, save images and predictions during training') args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/semantic_segmentation/advent.sh ================================================ # GTA5 to Cityscapes CUDA_VISIBLE_DEVICES=0 python advent.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \ --log logs/advent/gtav2cityscapes # Synthia to Cityscapes CUDA_VISIBLE_DEVICES=0 python advent.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \ --log logs/advent/synthia2cityscapes # Cityscapes to Foggy CUDA_VISIBLE_DEVICES=0 python advent.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \ --log logs/advent/cityscapes2foggy ================================================ FILE: examples/domain_adaptation/semantic_segmentation/cycada.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import sys import argparse import itertools import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader from torchvision.transforms 
import ToPILImage, Compose, Lambda sys.path.append('../../..') import tllib.translation.cyclegan as cyclegan from tllib.translation.cyclegan.util import ImagePool, set_requires_grad from tllib.translation.cycada import SemanticConsistency import tllib.vision.models.segmentation as models import tllib.vision.datasets.segmentation as datasets from tllib.vision.transforms import Denormalize, NormalizeAndTranspose import tllib.vision.transforms.segmentation as T from tllib.utils.data import ForeverDataIterator from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = T.Compose([ T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) source_dataset = datasets.__dict__[args.source] train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform) train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) target_dataset = datasets.__dict__[args.target] train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform) train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True) train_source_iter = ForeverDataIterator(train_source_loader) train_target_iter = ForeverDataIterator(train_target_loader) # define networks (both generators and discriminators) netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device) netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm, use_dropout=False).to(device) netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device) netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device) # create image buffer to store previously generated images fake_S_pool = ImagePool(args.pool_size) fake_T_pool = ImagePool(args.pool_size) # define optimizer and lr scheduler optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()), lr=args.lr, betas=(args.beta1, 0.999)) optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()), lr=args.lr, betas=(args.beta1, 0.999)) lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay) 
    # Learning-rate schedule: constant for `epochs` epochs, then linear decay
    # toward zero over `epochs_decay` epochs (the standard CycleGAN schedule,
    # via `lr_decay_function` defined above).
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint: restores both generators, both
    # discriminators, both optimizers, both schedulers and the epoch counter
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    # In 'test' phase, only translate the source dataset with the (possibly
    # resumed) source-to-target generator and exit.
    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()
    # semantic-consistency loss skips the dataset's ignore classes in addition
    # to the explicit ignore label
    criterion_semantic = SemanticConsistency(
        ignore_index=[args.ignore_label] + train_source_dataset.ignore_classes).to(device)
    # upsample segmentation logits back to the training resolution
    # (train_size is (W, H); Upsample wants (H, W), hence the [::-1])
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)

    # define segmentation model and predict function
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)
    if args.pretrain:
        print("Loading pretrain segmentation model from", args.pretrain)
        checkpoint = torch.load(args.pretrain, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
    # the segmentation net only scores images here, so keep it in eval mode
    model.eval()
    # bridge the two normalization conventions: CycleGAN tensors (mean/std 0.5)
    # -> HWC uint8-range image -> the segmentation model's expected input
    cycle_gan_tensor_to_segmentation_tensor = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        Lambda(lambda image: image.mul(255).permute((1, 2, 0))),
        NormalizeAndTranspose(),
    ])

    def predict(image):
        """Run the frozen segmentation model on a single CycleGAN-space image
        tensor and return logits upsampled to the training size."""
        image = cycle_gan_tensor_to_segmentation_tensor(image.squeeze())
        image = image.unsqueeze(dim=0).to(device)
        prediction = model(image)
        return interp_train(prediction)

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])
    decode = train_source_dataset.decode_target

    def visualize(image, name, pred=None):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
            pred (tensor): predictions in shape C x H x W
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))
        if pred is not None:
            pred = pred.detach().max(dim=0).indices.cpu().numpy()
            pred = decode(pred)
            pred.save(logger.get_image_path("pred_{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        # NOTE(review): `get_lr()` is deprecated in recent torch in favor of
        # `get_last_lr()`; left unchanged here as it is only a progress print.
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,
              criterion_gan, criterion_cycle, criterion_identity, criterion_semantic,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint (everything needed to resume)
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )

    # after training, optionally translate the whole source dataset with the
    # final source-to-target generator
    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()


def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T, predict,
          criterion_gan, criterion_cycle, criterion_identity, criterion_semantic,
          optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch: int, visualize,
          args: argparse.Namespace):
    """Train the CycleGAN generators and discriminators for one epoch with
    CyCADA's extra semantic-consistency losses (via `predict` /
    `criterion_semantic`)."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    losses_semantic_S2T = AverageMeter('sem_S2T', ':3.2f')
    losses_semantic_T2S = AverageMeter('sem_T2S', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T,
         losses_semantic_S2T, losses_semantic_T2S],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        # source batch carries segmentation labels; target labels are unused
        real_S, label_s = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)
        label_s = label_s.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients while the generators are updated
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(B))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        # Semantic loss: translated source images must keep the ground-truth
        # source labels; translated target images must keep the segmentation
        # model's own (pseudo-label) predictions on the real target images.
        pred_fake_T = predict(fake_T)
        pred_real_S = predict(real_S)
        loss_semantic_S2T = criterion_semantic(pred_fake_T, label_s) * args.trade_off_semantic
        pred_fake_S = predict(fake_S)
        pred_real_T = predict(real_T)
        loss_semantic_T2S = criterion_semantic(pred_fake_S, pred_real_T.max(1).indices) * args.trade_off_semantic
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + \
                 loss_identity_S + loss_identity_T + loss_semantic_S2T + loss_semantic_T2S
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S; the pool query mixes the
        # current (detached) fakes with a history of earlier fakes
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # record loss statistics and measure elapsed time
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        losses_semantic_S2T.update(loss_semantic_S2T.item(), real_S.size(0))
        losses_semantic_T2S.update(loss_semantic_T2S.item(), real_S.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            # save debugging images: reals/fakes with their predictions,
            # reconstructions and identity images without
            for image, prediction, name in zip([real_S, real_T, fake_S, fake_T],
                                               [pred_real_S, pred_real_T, pred_fake_S, pred_fake_T],
                                               ["real_S", "real_T", "fake_S", "fake_T"]):
                visualize(image[0], "{}_{}".format(i, name), prediction[0])
            for image, name in zip([rec_S, rec_T, identity_S, identity_T],
                                   ["rec_S", "rec_T", "identity_S", "identity_T"]):
                visualize(image[0], "{}_{}".format(i, name))


if __name__ == '__main__':
    # discover available segmentation backbones and dataset constructors
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    # dataset parameters
    parser = argparse.ArgumentParser(description='Cycada for Segmentation Domain Adaptation')
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(512, 256),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' +
                             ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    # NOTE(review): "segementation" typo below is user-visible help text;
    # changing it would alter a runtime string, so it is left unchanged here.
    parser.add_argument('--pretrain', type=str, default=None,
                        help='pretrain checkpoints for segementation model')
    parser.add_argument('--ngf', type=int, default=64,
                        help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64,
                        help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore cyclegan model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0,
                        help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0,
                        help='trade off for identity loss')
    parser.add_argument('--trade-off-semantic', type=float, default=1.0,
                        help='trade off for semantic loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    parser.add_argument('-p', '--print-freq', default=400, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='cycada',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    args = parser.parse_args()
    main(args)



================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycada.sh
================================================
# GTA5 to Cityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycada.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
  --log logs/cycada/gtav2cityscapes --pretrain logs/src_only/gtav2cityscapes/checkpoints/59.pth \
  --translated-root data/GTA52Cityscapes/cycada_39
# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/cycada_39 data/Cityscapes \
  -s GTA5 -t Cityscapes --log logs/cycada_src_only/gtav2cityscapes

## Synthia to Cityscapes
# First, train the Cycada
CUDA_VISIBLE_DEVICES=0 python cycada.py data/synthia data/Cityscapes -s Synthia -t Cityscapes \
  --log logs/cycada/synthia2cityscapes --pretrain logs/src_only/synthia2cityscapes/checkpoints/59.pth \
  --translated-root data/Synthia2Cityscapes/cycada_39
# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Synthia2Cityscapes/cycada_39 data/Cityscapes \
  -s Synthia -t Cityscapes --log logs/cycada_src_only/synthia2cityscapes

# Cityscapes to FoggyCityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycada.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
  --log logs/cycada/cityscapes2foggy --pretrain logs/src_only/cityscapes2foggy/checkpoints/59.pth \
  --translated-root data/Cityscapes2Foggy/cycada_39
# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/cycada_39 data/Cityscapes \
  -s Cityscapes -t FoggyCityscapes --log logs/cycada_src_only/cityscapes2foggy


================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycle_gan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse
import itertools

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage, Compose

sys.path.append('../../..')
import tllib.translation.cyclegan as cyclegan
from tllib.translation.cyclegan.util import ImagePool, set_requires_grad
import tllib.vision.datasets.segmentation as datasets
from tllib.vision.transforms import Denormalize
import tllib.vision.transforms.segmentation as T
from tllib.utils.data import ForeverDataIterator
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    """Train a plain CycleGAN between source and target segmentation datasets,
    then optionally translate the source dataset with the learned generator."""
    # CompleteLogger places logs / checkpoints / debug images under args.log
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: both domains share the same augmentation pipeline,
    # normalized to [-1, 1] as CycleGAN expects
    train_transform = T.Compose([
        T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root, transforms=train_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers,
                                     pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root, transforms=train_transform)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers,
                                     pin_memory=True, drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # define networks (both generators and discriminators)
    netG_S2T = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm,
                                                      use_dropout=False).to(device)
    netG_T2S = cyclegan.generator.__dict__[args.netG](ngf=args.ngf, norm=args.norm,
                                                      use_dropout=False).to(device)
    netD_S = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)
    netD_T = cyclegan.discriminator.__dict__[args.netD](ndf=args.ndf, norm=args.norm).to(device)

    # create image buffer to store previously generated images
    fake_S_pool = ImagePool(args.pool_size)
    fake_T_pool = ImagePool(args.pool_size)

    # define optimizer and lr scheduler (one optimizer per generator pair /
    # discriminator pair; linear decay after `epochs` epochs)
    optimizer_G = Adam(itertools.chain(netG_S2T.parameters(), netG_T2S.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    optimizer_D = Adam(itertools.chain(netD_S.parameters(), netD_T.parameters()),
                       lr=args.lr, betas=(args.beta1, 0.999))
    lr_decay_function = lambda epoch: 1.0 - max(0, epoch - args.epochs) / float(args.epochs_decay)
    lr_scheduler_G = LambdaLR(optimizer_G, lr_lambda=lr_decay_function)
    lr_scheduler_D = LambdaLR(optimizer_D, lr_lambda=lr_decay_function)

    # optionally resume from a checkpoint
    if args.resume:
        print("Resume from", args.resume)
        checkpoint = torch.load(args.resume, map_location='cpu')
        netG_S2T.load_state_dict(checkpoint['netG_S2T'])
        netG_T2S.load_state_dict(checkpoint['netG_T2S'])
        netD_S.load_state_dict(checkpoint['netD_S'])
        netD_T.load_state_dict(checkpoint['netD_T'])
        optimizer_G.load_state_dict(checkpoint['optimizer_G'])
        optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        lr_scheduler_G.load_state_dict(checkpoint['lr_scheduler_G'])
        lr_scheduler_D.load_state_dict(checkpoint['lr_scheduler_D'])
        args.start_epoch = checkpoint['epoch'] + 1

    # in 'test' phase, only translate the source dataset and exit
    if args.phase == 'test':
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)
        return

    # define loss function
    criterion_gan = cyclegan.LeastSquaresGenerativeAdversarialLoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # define visualization function
    tensor_to_image = Compose([
        Denormalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ToPILImage()
    ])

    def visualize(image, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            name: name of the saving image
        """
        tensor_to_image(image).save(logger.get_image_path("{}.png".format(name)))

    # start training
    for epoch in range(args.start_epoch, args.epochs + args.epochs_decay):
        logger.set_epoch(epoch)
        # NOTE(review): `get_lr()` is deprecated in recent torch; progress
        # print only, left unchanged.
        print(lr_scheduler_G.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
              criterion_gan, criterion_cycle, criterion_identity,
              optimizer_G, optimizer_D, fake_S_pool, fake_T_pool, epoch, visualize, args)

        # update learning rates
        lr_scheduler_G.step()
        lr_scheduler_D.step()

        # save checkpoint
        torch.save(
            {
                'netG_S2T': netG_S2T.state_dict(),
                'netG_T2S': netG_T2S.state_dict(),
                'netD_S': netD_S.state_dict(),
                'netD_T': netD_T.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
                'lr_scheduler_G': lr_scheduler_G.state_dict(),
                'lr_scheduler_D': lr_scheduler_D.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )

    # after training, optionally translate the whole source dataset
    if args.translated_root is not None:
        transform = T.Compose([
            T.Resize(image_size=args.test_input_size),
            T.wrapper(cyclegan.transform.Translation)(netG_S2T, device),
        ])
        train_source_dataset.translate(transform, args.translated_root)

    logger.close()


def train(train_source_iter, train_target_iter, netG_S2T, netG_T2S, netD_S, netD_T,
          criterion_gan, criterion_cycle, criterion_identity, optimizer_G, optimizer_D,
          fake_S_pool, fake_T_pool, epoch: int, visualize, args: argparse.Namespace):
    """Train the CycleGAN generators and discriminators for one epoch."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_G_S2T = AverageMeter('G_S2T', ':3.2f')
    losses_G_T2S = AverageMeter('G_T2S', ':3.2f')
    losses_D_S = AverageMeter('D_S', ':3.2f')
    losses_D_T = AverageMeter('D_T', ':3.2f')
    losses_cycle_S = AverageMeter('cycle_S', ':3.2f')
    losses_cycle_T = AverageMeter('cycle_T', ':3.2f')
    losses_identity_S = AverageMeter('idt_S', ':3.2f')
    losses_identity_T = AverageMeter('idt_T', ':3.2f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_G_S2T, losses_G_T2S, losses_D_S, losses_D_T,
         losses_cycle_S, losses_cycle_T, losses_identity_S, losses_identity_T],
        prefix="Epoch: [{}]".format(epoch))

    end = time.time()
    for i in range(args.iters_per_epoch):
        # labels are not needed for pure image translation
        real_S, _ = next(train_source_iter)
        real_T, _ = next(train_target_iter)
        real_S = real_S.to(device)
        real_T = real_T.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # Compute fake images and reconstruction images.
        fake_T = netG_S2T(real_S)
        rec_S = netG_T2S(fake_T)
        fake_S = netG_T2S(real_T)
        rec_T = netG_S2T(fake_S)

        # Optimizing generators
        # discriminators require no gradients while the generators are updated
        set_requires_grad(netD_S, False)
        set_requires_grad(netD_T, False)
        optimizer_G.zero_grad()
        # GAN loss D_T(G_S2T(S))
        loss_G_S2T = criterion_gan(netD_T(fake_T), real=True)
        # GAN loss D_S(G_T2S(B))
        loss_G_T2S = criterion_gan(netD_S(fake_S), real=True)
        # Cycle loss || G_T2S(G_S2T(S)) - S||
        loss_cycle_S = criterion_cycle(rec_S, real_S) * args.trade_off_cycle
        # Cycle loss || G_S2T(G_T2S(T)) - T||
        loss_cycle_T = criterion_cycle(rec_T, real_T) * args.trade_off_cycle
        # Identity loss
        # G_S2T should be identity if real_T is fed: ||G_S2T(real_T) - real_T||
        identity_T = netG_S2T(real_T)
        loss_identity_T = criterion_identity(identity_T, real_T) * args.trade_off_identity
        # G_T2S should be identity if real_S is fed: ||G_T2S(real_S) - real_S||
        identity_S = netG_T2S(real_S)
        loss_identity_S = criterion_identity(identity_S, real_S) * args.trade_off_identity
        # combined loss and calculate gradients
        loss_G = loss_G_S2T + loss_G_T2S + loss_cycle_S + loss_cycle_T + loss_identity_S + loss_identity_T
        loss_G.backward()
        optimizer_G.step()

        # Optimize discriminator
        set_requires_grad(netD_S, True)
        set_requires_grad(netD_T, True)
        optimizer_D.zero_grad()
        # Calculate GAN loss for discriminator D_S; the pool query mixes the
        # current (detached) fakes with a history of earlier fakes
        fake_S_ = fake_S_pool.query(fake_S.detach())
        loss_D_S = 0.5 * (criterion_gan(netD_S(real_S), True) + criterion_gan(netD_S(fake_S_), False))
        loss_D_S.backward()
        # Calculate GAN loss for discriminator D_T
        fake_T_ = fake_T_pool.query(fake_T.detach())
        loss_D_T = 0.5 * (criterion_gan(netD_T(real_T), True) + criterion_gan(netD_T(fake_T_), False))
        loss_D_T.backward()
        optimizer_D.step()

        # record loss statistics and measure elapsed time
        losses_G_S2T.update(loss_G_S2T.item(), real_S.size(0))
        losses_G_T2S.update(loss_G_T2S.item(), real_S.size(0))
        losses_D_S.update(loss_D_S.item(), real_S.size(0))
        losses_D_T.update(loss_D_T.item(), real_S.size(0))
        losses_cycle_S.update(loss_cycle_S.item(), real_S.size(0))
        losses_cycle_T.update(loss_cycle_T.item(), real_S.size(0))
        losses_identity_S.update(loss_identity_S.item(), real_S.size(0))
        losses_identity_T.update(loss_identity_T.item(), real_S.size(0))
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            # save the first image of every batch role for visual debugging
            for tensor, name in zip([real_S, real_T, fake_S, fake_T, rec_S, rec_T,
                                     identity_S, identity_T],
                                    ["real_S", "real_T", "fake_S", "fake_T", "rec_S", "rec_T",
                                     "identity_S", "identity_T"]):
                visualize(tensor[0], "{}_{}".format(i, name))


if __name__ == '__main__':
    # discover available dataset constructors
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )

    parser = argparse.ArgumentParser(description='CycleGAN for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    # model parameters
    parser.add_argument('--ngf', type=int, default=64,
                        help='# of gen filters in the last conv layer')
    parser.add_argument('--ndf', type=int, default=64,
                        help='# of discrim filters in the first conv layer')
    parser.add_argument('--netD', type=str, default='patch',
                        help='specify discriminator architecture [patch | pixel]. The basic model is a 70x70 PatchGAN.')
    parser.add_argument('--netG', type=str, default='unet_256',
                        help='specify generator architecture [resnet_9 | resnet_6 | unet_256 | unet_128]')
    parser.add_argument('--norm', type=str, default='instance',
                        help='instance normalization or batch normalization [instance | batch | none]')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    parser.add_argument('--trade-off-cycle', type=float, default=10.0,
                        help='trade off for cycle loss')
    parser.add_argument('--trade-off-identity', type=float, default=5.0,
                        help='trade off for identity loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=1, type=int, metavar='N',
                        help='mini-batch size (default: 1)')
    parser.add_argument('--lr', type=float, default=0.0002,
                        help='initial learning rate for adam')
    parser.add_argument('--beta1', type=float, default=0.5,
                        help='momentum term of adam')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--epochs-decay', type=int, default=20,
                        help='number of epochs to linearly decay learning rate to zero')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=5000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('--pool-size', type=int, default=50,
                        help='the size of image buffer that stores previously generated images')
    parser.add_argument('-p', '--print-freq', default=400, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='cyclegan',
                        help="Where to save logs, checkpoints and debugging images.")
    # test parameters
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--translated-root', type=str, default=None,
                        help="The root to put the translated dataset")
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    args = parser.parse_args()
    main(args)



================================================
FILE: examples/domain_adaptation/semantic_segmentation/cycle_gan.sh
================================================
# GTA5 to Cityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/GTA5 data/Cityscapes -s GTA5 -t Cityscapes \
  --log logs/cyclegan/gtav2cityscapes --translated-root data/GTA52Cityscapes/CycleGAN_39
# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/GTA52Cityscapes/CycleGAN_39 data/Cityscapes \
  -s GTA5 -t Cityscapes --log logs/cyclegan_src_only/gtav2cityscapes

# Cityscapes to FoggyCityscapes
# First, train the CycleGAN
CUDA_VISIBLE_DEVICES=0 python cycle_gan.py data/Cityscapes data/Cityscapes -s Cityscapes -t FoggyCityscapes \
  --log logs/cyclegan/cityscapes2foggy --translated-root data/Cityscapes2Foggy/CycleGAN_39
# Then, train the src_only model on the translated source dataset
CUDA_VISIBLE_DEVICES=0 python source_only.py data/Cityscapes2Foggy/CycleGAN_39 data/Cityscapes \
  -s Cityscapes -t FoggyCityscapes --log logs/cyclegan_src_only/cityscapes2foggy


================================================
FILE: examples/domain_adaptation/semantic_segmentation/erm.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import sys
import argparse

from PIL import Image
import numpy as np
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

sys.path.append('../../..')
import tllib.vision.models.segmentation as models
import tllib.vision.datasets.segmentation as datasets
import tllib.vision.transforms.segmentation as T
from tllib.vision.transforms import DeNormalizeAndTranspose
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter, Meter
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Train a segmentation model on the source domain only (ERM baseline)
    and evaluate it on the target-domain validation split."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        # Reproducible runs require both seeding and cuDNN deterministic mode,
        # which disables autotuned (faster) convolution algorithms.
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: the source training split gets random crop / color
    # jitter / flip augmentation; the target validation split only resizing.
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=args.resize_scale),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)
    target_dataset = datasets.__dict__[args.target]
    val_target_dataset = target_dataset(
        root=args.target_root, split='val',
        transforms=T.Compose([
            T.Resize(image_size=args.test_input_size, label_size=args.test_output_size),
            T.NormalizeAndTranspose(),
        ]),
    )
    # batch_size=1: validation images are evaluated at their full test size
    val_target_loader = DataLoader(val_target_dataset, batch_size=1, shuffle=False, pin_memory=True)

    train_source_iter = ForeverDataIterator(train_source_loader)

    # create model
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)

    # define optimizer and lr scheduler (polynomial decay over all iterations,
    # controlled by args.lr_power)
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))

    # optionally resume from a checkpoint (model + optimizer + scheduler state)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion); pixels labeled args.ignore_label do not
    # contribute to the cross-entropy loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    # NOTE(review): size arguments appear to be (width, height) and are
    # reversed here because nn.Upsample expects (H, W) -- confirm
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        # channel-wise argmax turns logits into a hard per-pixel prediction
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_source_iter, model, interp_train, criterion, optimizer, lr_scheduler, epoch,
              visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion,
                           visualize if args.debug else None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes: only the dataset's
        # `evaluate_classes` subset counts toward the benchmark metric
        indexes = [train_source_dataset.classes.index(name) for name in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best iou and save checkpoint (copied to 'best' when improved)
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()


def train(train_source_iter: ForeverDataIterator, model, interp, criterion,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """Train `model` for one epoch (`args.iters_per_epoch` iterations) with
    plain cross-entropy on source-domain batches from `train_source_iter`.

    `visualize` may be None; when given, it is called every `args.print_freq`
    iterations to dump the first sample of the batch for debugging.
    """
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    # confusion matrix accumulates over the whole epoch, so the reported
    # Acc/IoU are running (not per-batch) statistics
    confmat_s = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, accuracies_s, iou_s],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output (logits upsampled to label resolution before the loss)
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        accuracies_s.update(acc_s.mean().item())
        iou_s.update(iu_s.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Run a full pass over `val_loader` and return the accumulated
    confusion matrix (callers derive accuracy / IoU from it)."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(len(val_loader), [batch_time, losses, acc, iou], prefix='Test: ')

    # evaluation mode: freezes batch-norm statistics and disables dropout
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            images = images.to(device)
            target = target.long().to(device)

            # forward pass, upsampled to the label resolution
            logits = interp(model(images))
            loss = criterion(logits, target)

            # bookkeeping: running loss plus confusion-matrix-derived metrics
            losses.update(loss.item(), images.size(0))
            confmat.update(target.flatten(), logits.argmax(1).flatten())
            acc_global, accs, iu = confmat.compute()
            acc.update(accs.mean().item())
            iou.update(iu.mean().item())

            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)
                if visualize is not None:
                    visualize(images[0], logits[0], target[0], "val_{}".format(step))

    return confmat


if __name__ == '__main__':
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='Source Only for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--resize-scale', nargs='+', type=float, default=(0.5, 1.),
                        help='the resize scale for the random resize crop')
    # NOTE(review): the size arguments appear to be given as (width, height);
    # main() reverses them for nn.Upsample -- confirm
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),
                        help='the output image size during test')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' + ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=2, type=int, metavar='N',
                        help='mini-batch size (default: 2)')
    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="Momentum component of the optimiser.")
    parser.add_argument("--weight-decay", type=float, default=0.0005,
                        help="Regularisation parameter for L2-loss.")
    parser.add_argument("--lr-power", type=float, default=0.9,
                        help="Decay parameter to compute the learning rate (only for deeplab).")
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions during training')
    args = parser.parse_args()
    main(args)
def robust_entropy(y, ita=1.5, num_classes=19, reduction='mean'):
    """Robust entropy proposed in
    `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020)`.

    The per-pixel Shannon entropy of the predicted class distribution is
    normalized by ``log(num_classes)`` (a uniform prediction has entropy 1)
    and then sharpened with ``(ent ** 2 + 1e-8) ** ita``, which suppresses
    low-entropy (confident) pixels relative to high-entropy ones.

    Args:
        y (tensor): logits output of segmentation model in shape of :math:`(N, C, H, W)`
        ita (float, optional): parameters for robust entropy. Default: 1.5
        num_classes (int, optional): number of classes; should equal the channel
            dimension ``C`` of `y` for the normalization to be exact. Default: 19
        reduction (string, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of
            elements in the output. Default: ``'mean'``

    Returns:
        Scalar by default. If :attr:`reduction` is ``'none'``, a tensor of shape
        :math:`(N, H, W)` with one robust-entropy value per pixel.
        (Fixed: the previous docstring incorrectly claimed shape :math:`(N,)` —
        no spatial reduction is performed in the ``'none'`` branch.)
    """
    P = F.softmax(y, dim=1)
    logP = F.log_softmax(y, dim=1)
    PlogP = P * logP
    # per-pixel entropy, shape (N, H, W)
    ent = -1.0 * PlogP.sum(dim=1)
    # normalize so the maximum (uniform) entropy equals 1
    ent = ent / math.log(num_classes)

    # compute robust entropy: square + small eps for numerical stability,
    # then raise to `ita` to emphasize high-entropy pixels
    ent = ent ** 2.0 + 1e-8
    ent = ent ** ita
    if reduction == 'mean':
        return ent.mean()
    else:
        return ent
    source_dataset = datasets.__dict__[args.source]
    # The source pipeline resizes to the target resolution, translates the image
    # into the target "style" via the Fourier transform, then applies the usual
    # training augmentations.
    train_source_dataset = source_dataset(
        root=args.source_root,
        transforms=T.Compose([
            T.Resize((2048, 1024)),  # convert source image to the size of the target image before fourier transform
            fourier_transform,
            T.RandomResizedCrop(size=args.train_size, ratio=args.resize_ratio, scale=(0.5, 1.)),
            T.ColorJitter(brightness=0.3, contrast=0.3),
            T.RandomHorizontalFlip(),
            T.NormalizeAndTranspose(),
        ]),
    )
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, pin_memory=True, drop_last=True)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    model = models.__dict__[args.arch](num_classes=train_source_dataset.num_classes).to(device)

    # define optimizer and lr scheduler (polynomial decay over all iterations,
    # controlled by args.lr_power)
    optimizer = SGD(model.get_parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    lr_scheduler = LambdaLR(optimizer,
                            lambda x: args.lr * (1. - float(x) / args.epochs / args.iters_per_epoch) ** (args.lr_power))

    # optionally resume from a checkpoint (model + optimizer + scheduler state)
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch'] + 1

    # define loss function (criterion); pixels labeled args.ignore_label do not
    # contribute to the cross-entropy loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=args.ignore_label).to(device)
    # NOTE(review): size arguments appear to be (width, height), reversed here
    # because nn.Upsample expects (H, W) -- confirm
    interp_train = nn.Upsample(size=args.train_size[::-1], mode='bilinear', align_corners=True)
    interp_val = nn.Upsample(size=args.test_output_size[::-1], mode='bilinear', align_corners=True)

    # define visualization function
    decode = train_source_dataset.decode_target

    def visualize(image, pred, label, prefix):
        """
        Args:
            image (tensor): 3 x H x W
            pred (tensor): C x H x W
            label (tensor): H x W
            prefix: prefix of the saving image
        """
        image = image.detach().cpu().numpy()
        # channel-wise argmax turns logits into a hard per-pixel prediction
        pred = pred.detach().max(dim=0)[1].cpu().numpy()
        label = label.cpu().numpy()
        for tensor, name in [
            (Image.fromarray(np.uint8(DeNormalizeAndTranspose()(image))), "image"),
            (decode(label), "label"),
            (decode(pred), "pred")
        ]:
            tensor.save(logger.get_image_path("{}_{}.png".format(prefix, name)))

    if args.phase == 'test':
        confmat = validate(val_target_loader, model, interp_val, criterion, visualize, args)
        print(confmat)
        return

    # start training
    best_iou = 0.
    for epoch in range(args.start_epoch, args.epochs):
        logger.set_epoch(epoch)
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(train_source_iter, train_target_iter, model, interp_train, criterion, optimizer,
              lr_scheduler, epoch, visualize if args.debug else None, args)

        # evaluate on validation set
        confmat = validate(val_target_loader, model, interp_val, criterion, None, args)
        print(confmat.format(train_source_dataset.classes))
        acc_global, acc, iu = confmat.compute()

        # calculate the mean iou over partial classes: only the dataset's
        # `evaluate_classes` subset counts toward the benchmark metric
        indexes = [train_source_dataset.classes.index(name) for name in train_source_dataset.evaluate_classes]
        iu = iu[indexes]
        mean_iou = iu.mean()

        # remember best acc@1 and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch)
        )
        if mean_iou > best_iou:
            shutil.copy(logger.get_checkpoint_path(epoch), logger.get_checkpoint_path('best'))
        best_iou = max(best_iou, mean_iou)
        print("Target: {} Best: {}".format(mean_iou, best_iou))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator, model, interp,
          criterion, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, visualize, args: argparse.Namespace):
    """One FDA training epoch: supervised cross-entropy on (Fourier-translated)
    source batches plus weighted robust-entropy minimization on target batches."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_t = AverageMeter('Loss (t)', ':3.2f')
    losses_entropy_t = AverageMeter('Entropy (t)', ':3.2f')
    accuracies_s = Meter('Acc (s)', ':3.2f')
    accuracies_t = Meter('Acc (t)', ':3.2f')
    iou_s = Meter('IoU (s)', ':3.2f')
    iou_t = Meter('IoU (t)', ':3.2f')
    # confusion matrices accumulate over the whole epoch (running statistics)
    confmat_s = ConfusionMatrix(model.num_classes)
    confmat_t = ConfusionMatrix(model.num_classes)
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses_s, losses_t, losses_entropy_t,
         accuracies_s, accuracies_t, iou_s, iou_t],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        optimizer.zero_grad()

        x_s, label_s = next(train_source_iter)
        x_t, label_t = next(train_target_iter)
        x_s = x_s.to(device)
        label_s = label_s.long().to(device)
        x_t = x_t.to(device)
        label_t = label_t.long().to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output: supervised CE on the translated source batch
        y_s = model(x_s)
        pred_s = interp(y_s)
        loss_cls_s = criterion(pred_s, label_s)
        loss_cls_s.backward()

        y_t = model(x_t)
        pred_t = interp(y_t)
        # target CE loss is recorded for monitoring only -- it is never
        # back-propagated, so target labels are not used for training
        loss_cls_t = criterion(pred_t, label_t)
        # only the weighted robust-entropy term contributes target gradients
        loss_entropy_t = robust_entropy(y_t, args.ita)
        (args.entropy_weight * loss_entropy_t).backward()

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure accuracy and record loss
        losses_s.update(loss_cls_s.item(), x_s.size(0))
        losses_t.update(loss_cls_t.item(), x_s.size(0))
        losses_entropy_t.update(loss_entropy_t.item(), x_s.size(0))
        confmat_s.update(label_s.flatten(), pred_s.argmax(1).flatten())
        confmat_t.update(label_t.flatten(), pred_t.argmax(1).flatten())
        acc_global_s, acc_s, iu_s = confmat_s.compute()
        acc_global_t, acc_t, iu_t = confmat_t.compute()
        accuracies_s.update(acc_s.mean().item())
        accuracies_t.update(acc_t.mean().item())
        iou_s.update(iu_s.mean().item())
        iou_t.update(iu_t.mean().item())

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
            if visualize is not None:
                visualize(x_s[0], pred_s[0], label_s[0], "source_{}".format(i))
                visualize(x_t[0], pred_t[0], label_t[0], "target_{}".format(i))


def validate(val_loader: DataLoader, model, interp, criterion, visualize, args: argparse.Namespace):
    """Run a full pass over `val_loader` and return the accumulated
    confusion matrix (callers derive accuracy / IoU from it)."""
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = Meter('Acc', ':3.2f')
    iou = Meter('IoU', ':3.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc, iou],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    confmat = ConfusionMatrix(model.num_classes)

    with torch.no_grad():
        end = time.time()
        for i, (x, label) in enumerate(val_loader):
            x = x.to(device)
            label = label.long().to(device)

            # compute output, upsampled to the label resolution
            output = interp(model(x))
            loss = criterion(output, label)

            # measure accuracy and record loss
            losses.update(loss.item(), x.size(0))
            confmat.update(label.flatten(), output.argmax(1).flatten())
            acc_global, accs, iu = confmat.compute()
            acc.update(accs.mean().item())
            iou.update(iu.mean().item())

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)
                if visualize is not None:
                    visualize(x[0], output[0], label[0], "val_{}".format(i))

    return confmat


if __name__ == '__main__':
    architecture_names = sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    )
    dataset_names = sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )
    parser = argparse.ArgumentParser(description='FDA for Segmentation Domain Adaptation')
    # dataset parameters
    parser.add_argument('source_root', help='root path of the source dataset')
    parser.add_argument('target_root', help='root path of the target dataset')
    parser.add_argument('-s', '--source', help='source domain(s)')
    parser.add_argument('-t', '--target', help='target domain(s)')
    parser.add_argument('--resize-ratio', nargs='+', type=float, default=(1.5, 8 / 3.),
                        help='the resize ratio for the random resize crop')
    parser.add_argument('--train-size', nargs='+', type=int, default=(1024, 512),
                        help='the input and output image size during training')
    parser.add_argument('--test-input-size', nargs='+', type=int, default=(1024, 512),
                        help='the input image size during test')
    parser.add_argument('--test-output-size', nargs='+', type=int, default=(2048, 1024),
                        help='the output image size during test')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='deeplabv2_resnet101',
                        choices=architecture_names,
                        help='backbone architecture: ' + ' | '.join(architecture_names) +
                             ' (default: deeplabv2_resnet101)')
    # NOTE(review): entropy weight defaults to 0, i.e. the entropy term is off
    # unless explicitly enabled; the CLI `--ita` default (2.0) also differs from
    # robust_entropy's own default (1.5) -- confirm intended
    parser.add_argument("--entropy-weight", type=float, default=0., help="weight for entropy")
    parser.add_argument("--ita", type=float, default=2.0, help="ita for robust entropy")
    parser.add_argument("--beta", type=int, default=1, help="beta for FDA")
    parser.add_argument("--resume", type=str, default=None,
                        help="Where restore model parameters from.")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=2, type=int, metavar='N',
                        help='mini-batch size (default: 2)')
    parser.add_argument('--lr', '--learning-rate', default=2.5e-3, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument("--momentum", type=float, default=0.9,
                        help="Momentum component of the optimiser.")
    parser.add_argument("--weight-decay", type=float, default=0.0005,
                        help="Regularisation parameter for L2-loss.")
    parser.add_argument("--lr-power", type=float, default=0.9,
                        help="Decay parameter to compute the learning rate (only for deeplab).")
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('-i', '--iters-per-epoch', default=2500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training. ')
    parser.add_argument("--ignore-label", type=int, default=255,
                        help="The index of the label to ignore during the training.")
    parser.add_argument("--log", type=str, default='fda',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    parser.add_argument('--debug', action="store_true",
                        help='In the debug mode, save images and predictions during training')
    args = parser.parse_args()
    main(args)
Then run

```
pip install -r requirements.txt
```

## Dataset

Following datasets can be downloaded automatically:

- [DomainNet](http://ai.bu.edu/M3SDA/)
- [iwildcam (WILDS)](https://wilds.stanford.edu/datasets/)
- [camelyon17 (WILDS)](https://wilds.stanford.edu/datasets/)
- [fmow (WILDS)](https://wilds.stanford.edu/datasets/)

## Supported Methods

Supported methods include:

- [Domain Adversarial Neural Network (DANN)](https://arxiv.org/abs/1505.07818)
- [Deep Adaptation Network (DAN)](https://arxiv.org/pdf/1502.02791)
- [Joint Adaptation Network (JAN)](https://arxiv.org/abs/1605.06636)
- [Conditional Domain Adversarial Network (CDAN)](https://arxiv.org/abs/1705.10667)
- [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)
- [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch)](https://arxiv.org/abs/2001.07685)

## Usage

Our code is based on [https://github.com/NVIDIA/apex/tree/master/examples/imagenet](https://github.com/NVIDIA/apex/tree/master/examples/imagenet).
It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the WILDS dataset.
Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s.
For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html).
The shell files give all the training scripts we use, e.g.,

```
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121
```

## Results

### Performance on WILDS-FMoW (DenseNet-121)

| Methods  | Val Avg Acc | Test Avg Acc | Val Worst-region Acc | Test Worst-region Acc |
|----------|-------------|--------------|----------------------|-----------------------|
| ERM      | 59.8        | 53.3         | 50.2                 | 32.2                  |
| DANN     | 60.6        | 54.2         | 49.1                 | 34.8                  |
| DAN      | 61.7        | 55.5         | 48.3                 | 35.3                  |
| JAN      | 61.5        | 55.3         | 50.6                 | 36.3                  |
| CDAN     | 60.7        | 55.0         | 47.4                 | 35.5                  |
| MDD      | 60.1        | 55.1         | 49.3                 | 35.9                  |
| FixMatch | 61.1        | 55.1         | 51.8                 | 37.4                  |

### Performance on WILDS-IWildCAM (ResNet50)

| Methods | Val Avg Acc | Test Avg Acc | Val F1 macro | Test F1 macro |
|---------|-------------|--------------|--------------|---------------|
| ERM     | 59.9        | 72.6         | 36.3         | 32.9          |
| DANN    | 57.4        | 70.1         | 35.8         | 32.2          |
| DAN     | 63.7        | 69.4         | 39.1         | 31.6          |
| JAN     | 62.4        | 68.7         | 37.6         | 31.5          |
| CDAN    | 57.6        | 71.2         | 37.0         | 30.6          |
| MDD     | 58.3        | 73.5         | 35.0         | 30.0          |

### Visualization

We use tensorboard to record the training process and visualize the outputs of the models.

```
tensorboard --logdir=logs
```

### Distributed training

We use `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process.

```
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 --master_port 6666 erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 -j 8 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121_bs_128
```

## TODO

1. update experiment results
2. support DomainNet
3. support camelyon17
4. support self-training methods
5. support self-supervised methods

## Citation

If you use these methods in your research, please consider citing.
``` @inproceedings{DANN, author = {Ganin, Yaroslav and Lempitsky, Victor}, Booktitle = {ICML}, Title = {Unsupervised domain adaptation by backpropagation}, Year = {2015} } @inproceedings{DAN, author = {Mingsheng Long and Yue Cao and Jianmin Wang and Michael I. Jordan}, title = {Learning Transferable Features with Deep Adaptation Networks}, booktitle = {ICML}, year = {2015}, } @inproceedings{JAN, title={Deep transfer learning with joint adaptation networks}, author={Long, Mingsheng and Zhu, Han and Wang, Jianmin and Jordan, Michael I}, booktitle={ICML}, year={2017}, } @inproceedings{CDAN, author = {Mingsheng Long and Zhangjie Cao and Jianmin Wang and Michael I. Jordan}, title = {Conditional Adversarial Domain Adaptation}, booktitle = {NeurIPS}, year = {2018} } @inproceedings{MDD, title={Bridging theory and algorithm for domain adaptation}, author={Zhang, Yuchen and Liu, Tianle and Long, Mingsheng and Jordan, Michael}, booktitle={ICML}, year={2019}, } @inproceedings{FixMatch, title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence}, author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang}, booktitle={NIPS}, year={2020} } ``` ================================================ FILE: examples/domain_adaptation/wilds_image_classification/cdan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import argparse import os import shutil import time import pprint import math from itertools import cycle import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler import torchvision.models as models from torch.utils.tensorboard import SummaryWriter from timm.loss.cross_entropy import 
LabelSmoothingCrossEntropy import wilds try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp, optimizers from apex.multi_tensor_apply import multi_tensor_applier except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.alignment.cdan import ConditionalDomainAdversarialLoss, ImageClassifier as Classifier from tllib.utils.logger import CompleteLogger from tllib.utils.meter import AverageMeter from tllib.utils.metric import accuracy def main(args): writer = None if args.local_rank == 0: logger = CompleteLogger(args.log, args.phase) if args.phase == 'train': writer = SummaryWriter(args.log) pprint.pprint(args) print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) cudnn.benchmark = True best_prec1 = 0 if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.seed) torch.set_printoptions(precision=10) args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.gpu = 0 args.world_size = 1 if args.distributed: args.gpu = args.local_rank torch.cuda.set_device(args.gpu) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
    # Choose the tensor memory layout applied to the model below; channels-last
    # can speed up mixed-precision convolutions on recent NVIDIA GPUs.
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code: build the train/val augmentation pipelines from CLI flags.
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    # Only rank 0 prints, to avoid duplicated logs under distributed training.
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    # num_classes/class_names are stored back onto `args` so downstream code
    # (train loop, validation, plotting) can read them from a single place.
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    features_dim = model.features_dim
    # CDAN feeds the discriminator the (multilinear) map of features and
    # predictions: with --randomized that map is projected down to
    # `randomized_dim`, otherwise its size is features_dim * num_classes.
    if args.randomized:
        domain_discri = DomainDiscriminator(args.randomized_dim, hidden_size=1024, sigmoid=False)
    else:
        domain_discri = DomainDiscriminator(features_dim * args.num_classes, hidden_size=1024, sigmoid=False)

    if args.sync_bn:
        # apex sync BN keeps batch-norm statistics consistent across processes.
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    # Move both networks to GPU (in the requested memory format) here — this
    # must happen before the amp.initialize / DDP wrapping that follows.
    model = model.cuda().to(memory_format=memory_format)
    domain_discri = domain_discri.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    # (linear scaling rule; note args.lr is overwritten in place, so the
    # scheduler defined later sees the already-scaled value).
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
optimizer = torch.optim.SGD( model.get_parameters() + domain_discri.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) # Initialize Amp. Amp accepts either values or strings for the optional override arguments, # for convenient interoperation with argparse. (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale ) # Use cosine annealing learning rate strategy lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr) ) # define loss function domain_adv = ConditionalDomainAdversarialLoss( domain_discri, num_classes=args.num_classes, features_dim=features_dim, randomized=args.randomized, randomized_dim=args.randomized_dim, sigmoid=False ) # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. if args.distributed: # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. # model = DDP(model) # delay_allreduce delays all communication to the end of the backward pass. 
model = DDP(model, delay_allreduce=True) domain_adv = DDP(domain_adv, delay_allreduce=True) # define loss function (criterion) if args.smoothing: criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda() else: criterion = nn.CrossEntropyLoss().cuda() # Data loading code train_labeled_sampler = None train_unlabeled_sampler = None if args.distributed: train_labeled_sampler = DistributedSampler(train_labeled_dataset) train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset) train_labeled_loader = DataLoader( train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True) train_unlabeled_loader = DataLoader( train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True) if args.phase == 'test': # resume from the latest checkpoint checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) utils.validate(d, model, -1, writer, args) return for epoch in range(args.epochs): if args.distributed: train_labeled_sampler.set_epoch(epoch) train_unlabeled_sampler.set_epoch(epoch) lr_scheduler.step(epoch) if args.local_rank == 0: print(lr_scheduler.get_last_lr()) writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch) # train for one epoch train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer, args) # evaluate on validation set for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) prec1 = utils.validate(d, model, epoch, writer, args) # remember best prec@1 and save checkpoint if args.local_rank == 0: is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) torch.save(model.state_dict(), 
logger.get_checkpoint_path('latest')) if is_best: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) def train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer, epoch, writer, args): batch_time = AverageMeter('Time', ':3.1f') losses_s = AverageMeter('Loss (s)', ':3.2f') losses_trans = AverageMeter('Loss (transfer)', ':3.2f') domain_accs = AverageMeter('Domain Acc', ':3.1f') top1 = AverageMeter('Top 1', ':3.1f') # switch to train mode model.train() end = time.time() num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader)) for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \ zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)): # compute output n_s, n_t = len(input_s), len(input_t) input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0) output, feature = model(input) output_s, output_t = output.split([n_s, n_t], dim=0) feature_s, feature_t = feature.split([n_s, n_t], dim=0) loss_s = criterion(output_s, target_s.cuda()) loss_trans = domain_adv(output_s, feature_s, output_t, feature_t) loss = loss_s + loss_trans * args.trade_off # compute gradient and do SGD step optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() if i % args.print_freq == 0: # Every print_freq iterations, check the loss, accuracy, and speed. # For best performance, it doesn't make sense to print these metrics every # iteration, since they incur an allreduce and some host<->device syncs. 
# Measure accuracy prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,)) # Average loss and accuracy across processes for logging if args.distributed: reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size) reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size) prec1 = utils.reduce_tensor(prec1, args.world_size) domain_acc = domain_adv.module.domain_discriminator_accuracy else: reduced_loss_s = loss_s.data reduced_loss_trans = loss_trans.data domain_acc = domain_adv.domain_discriminator_accuracy # to_python_float incurs a host<->device sync losses_s.update(to_python_float(reduced_loss_s), input_s.size(0)) losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0)) domain_accs.update(to_python_float(domain_acc), input_s.size(0)) top1.update(to_python_float(prec1), input_s.size(0)) global_step = epoch * num_iterations + i torch.cuda.synchronize() batch_time.update((time.time() - end) / args.print_freq) end = time.time() if args.local_rank == 0: writer.add_scalar('train/top1', to_python_float(prec1), global_step) writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step) writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step) writer.add_figure('train/predictions vs. 
actuals', utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names, metadata_s, train_labeled_loader.dataset.metadata_map), global_step=global_step) print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Speed {3:.3f} ({4:.3f})\t' 'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t' 'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t' 'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( epoch, i, len(train_labeled_loader), args.world_size * args.batch_size[0] / batch_time.val, args.world_size * args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans, domain_acc=domain_accs, top1=top1)) if __name__ == '__main__': model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) parser = argparse.ArgumentParser(description='CDAN') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets, help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)') parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ]) parser.add_argument('--test-list', nargs='+', default=["val", "test"]) parser.add_argument('--metric', default="acc_worst_region") parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+', help='Image patch size (default: None => model default)') parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT', 
help='Random resize scale (default: 0.5 1.0)') parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') parser.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)') # model parameters parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--bottleneck-dim', default=512, type=int, help='Dimension of bottleneck') parser.add_argument('-r', '--randomized', action='store_true', help='using randomized multi-linear-map (default: False)') parser.add_argument('-rd', '--randomized-dim', default=1024, type=int, help='randomized dimension when using randomized multi-linear-map (default: 1024)') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # Learning rate schedule parameters parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='Initial learning rate. 
Will be scaled by /256: ' 'args.lr = args.lr*float(args.batch_size*args.world_size)/256.') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') # training parameters parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N', help='mini-batch size per process for source' ' and target domain (default: (64, 64))') parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)') parser.add_argument('--deterministic', action='store_true') parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ') parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int) parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.') parser.add_argument('--opt-level', type=str) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--channels-last', type=bool, default=False) parser.add_argument("--log", type=str, default='cdan', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_image_classification/cdan.sh ================================================ CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/cdan/fmow/lr_0_1_aa_v0_densenet121 CUDA_VISIBLE_DEVICES=0 python cdan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \ --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ --log logs/cdan/iwildcam/lr_1_deterministic ================================================ FILE: examples/domain_adaptation/wilds_image_classification/dan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import argparse import os import shutil import time import pprint import math from itertools import cycle import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler import torchvision.models as models from torch.utils.tensorboard import SummaryWriter from timm.loss.cross_entropy import LabelSmoothingCrossEntropy import wilds try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp, optimizers from apex.multi_tensor_apply import multi_tensor_applier except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from tllib.alignment.dan import MultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier from tllib.modules.kernels import GaussianKernel from tllib.utils.logger 
import CompleteLogger from tllib.utils.meter import AverageMeter from tllib.utils.metric import accuracy def main(args): writer = None if args.local_rank == 0: logger = CompleteLogger(args.log, args.phase) if args.phase == 'train': writer = SummaryWriter(args.log) pprint.pprint(args) print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) cudnn.benchmark = True best_prec1 = 0 if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.seed) torch.set_printoptions(precision=10) args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.gpu = 0 args.world_size = 1 if args.distributed: args.gpu = args.local_rank torch.cuda.set_device(args.gpu) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
    # Choose the tensor memory layout applied to the model below; channels-last
    # can speed up mixed-precision convolutions on recent NVIDIA GPUs.
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code: build the train/val augmentation pipelines from CLI flags.
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    # Only rank 0 prints, to avoid duplicated logs under distributed training.
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    # num_classes/class_names are stored back onto `args` so downstream code
    # (train loop, validation, plotting) can read them from a single place.
    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)

    if args.sync_bn:
        # apex sync BN keeps batch-norm statistics consistent across processes.
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    # Move the model to GPU (in the requested memory format) here — this must
    # happen before the amp.initialize / DDP wrapping that follows.
    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    # (linear scaling rule; note args.lr is overwritten in place, so the
    # scheduler defined later sees the already-scaled value).
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.

    # Plain SGD with Nesterov momentum over the classifier's per-module
    # parameter groups (backbone vs. new layers may use different lr factors).
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr,
        momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale ) # Use cosine annealing learning rate strategy lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr) ) # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. if args.distributed: # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. # model = DDP(model) # delay_allreduce delays all communication to the end of the backward pass. model = DDP(model, delay_allreduce=True) # define loss function (criterion) if args.smoothing: criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda() else: criterion = nn.CrossEntropyLoss().cuda() # Data loading code train_labeled_sampler = None train_unlabeled_sampler = None if args.distributed: train_labeled_sampler = DistributedSampler(train_labeled_dataset) train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset) train_labeled_loader = DataLoader( train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True) train_unlabeled_loader = DataLoader( train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True) if args.phase == 'test': # resume from the latest checkpoint checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') 
model.load_state_dict(checkpoint) for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) utils.validate(d, model, -1, writer, args) return # define loss function mkmmd_loss = MultipleKernelMaximumMeanDiscrepancy( kernels=[GaussianKernel(alpha=2 ** k) for k in range(-3, 2)], linear=not args.non_linear ) for epoch in range(args.epochs): if args.distributed: train_labeled_sampler.set_epoch(epoch) train_unlabeled_sampler.set_epoch(epoch) lr_scheduler.step(epoch) print(lr_scheduler.get_last_lr()) writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch) # train for one epoch train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer, args) # evaluate on validation set for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) prec1 = utils.validate(d, model, epoch, writer, args) # remember best prec@1 and save checkpoint if args.local_rank == 0: is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if is_best: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) def train(train_labeled_loader, train_unlabeled_loader, model, criterion, mkmmd_loss, optimizer, epoch, writer, args): batch_time = AverageMeter('Time', ':3.1f') losses_s = AverageMeter('Loss (s)', ':3.2f') losses_trans = AverageMeter('Loss (transfer)', ':3.2f') top1 = AverageMeter('Top 1', ':3.1f') # switch to train mode model.train() end = time.time() num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader)) for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \ zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)): # compute output n_s, n_t = len(input_s), len(input_t) input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0) output, feature = model(input) output_s, output_t = output.split([n_s, n_t], dim=0) feature_s, feature_t = 
feature.split([n_s, n_t], dim=0) loss_s = criterion(output_s, target_s.cuda()) loss_trans = mkmmd_loss(feature_s, feature_t) loss = loss_s + loss_trans * args.trade_off # compute gradient and do SGD step optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() if i % args.print_freq == 0: # Every print_freq iterations, check the loss, accuracy, and speed. # For best performance, it doesn't make sense to print these metrics every # iteration, since they incur an allreduce and some host<->device syncs. # Measure accuracy prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,)) # Average loss and accuracy across processes for logging if args.distributed: reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size) reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size) prec1 = utils.reduce_tensor(prec1, args.world_size) else: reduced_loss_s = loss_s.data reduced_loss_trans = loss_trans.data # to_python_float incurs a host<->device sync losses_s.update(to_python_float(reduced_loss_s), input_s.size(0)) losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0)) top1.update(to_python_float(prec1), input_s.size(0)) global_step = epoch * num_iterations + i torch.cuda.synchronize() batch_time.update((time.time() - end) / args.print_freq) end = time.time() if args.local_rank == 0: writer.add_scalar('train/top1', to_python_float(prec1), global_step) writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step) writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step) writer.add_figure('train/predictions vs. 
actuals', utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names, metadata_s, train_labeled_loader.dataset.metadata_map), global_step=global_step) print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Speed {3:.3f} ({4:.3f})\t' 'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t' 'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( epoch, i, len(train_labeled_loader), args.world_size * args.batch_size[0] / batch_time.val, args.world_size * args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans, top1=top1)) if __name__ == '__main__': model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) parser = argparse.ArgumentParser(description='DAN') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets, help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)') parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ]) parser.add_argument('--test-list', nargs='+', default=["val", "test"]) parser.add_argument('--metric', default="acc_worst_region") parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+', help='Image patch size (default: None => model default)') parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT', help='Random resize scale (default: 0.5 1.0)') parser.add_argument('--ratio', type=float, 
nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') parser.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)') # model parameters parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--bottleneck-dim', default=512, type=int, help='Dimension of bottleneck') parser.add_argument('--non-linear', default=False, action='store_true', help='whether not use the linear version') parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # Learning rate schedule parameters parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='Initial learning rate. 
Will be scaled by /256: ' 'args.lr = args.lr*float(args.batch_size*args.world_size)/256.') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') # training parameters parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N', help='mini-batch size per process for source' ' and target domain (default: (64, 64))') parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)') parser.add_argument('--deterministic', action='store_true') parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ') parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int) parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.') parser.add_argument('--opt-level', type=str) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--channels-last', type=bool, default=False) parser.add_argument("--log", type=str, default='dan', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_image_classification/dan.sh ================================================ CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dan/fmow/lr_0_1_aa_v0_densenet121 CUDA_VISIBLE_DEVICES=0 python dan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \ --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ --log logs/dan/iwildcam/lr_0_3_deterministic ================================================ FILE: examples/domain_adaptation/wilds_image_classification/dann.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import argparse import os import shutil import time import pprint import math from itertools import cycle import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler import torchvision.models as models from torch.utils.tensorboard import SummaryWriter from timm.loss.cross_entropy import LabelSmoothingCrossEntropy import wilds try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp, optimizers from apex.multi_tensor_apply import multi_tensor_applier except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from tllib.modules.domain_discriminator import DomainDiscriminator from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier as Classifier from tllib.utils.logger 
import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy


def main(args):
    """Entry point: DANN domain-adversarial training on a WILDS classification dataset.

    Sets up logging/seeding, optional NCCL distributed training, the data
    pipeline, the classifier plus domain discriminator, apex AMP and
    (optionally) apex DistributedDataParallel, then runs the train/validate
    loop, keeping the best checkpoint.  In 'test' phase it only evaluates the
    best checkpoint.

    Args:
        args: parsed command-line arguments (see the argparse setup below).
    """
    writer = None
    if args.local_rank == 0:
        # Only rank 0 owns the logger and the TensorBoard writer.
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade speed for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        # WORLD_SIZE is set by the torch.distributed launcher.
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    features_dim = model.features_dim
    # Binary domain discriminator over the bottleneck features.
    domain_discri = DomainDiscriminator(features_dim, hidden_size=1024, sigmoid=False)

    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)
    domain_discri = domain_discri.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(model.get_parameters() + domain_discri.get_parameters(), args.lr,
                                momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    (model, domain_discri), optimizer = amp.initialize([model, domain_discri], optimizer,
                                                       opt_level=args.opt_level,
                                                       keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                                       loss_scale=args.loss_scale
                                                       )
    # Use cosine annealing learning rate strategy
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by the lambda's
    # return value, and this lambda itself multiplies by args.lr, so the effective
    # lr is proportional to args.lr squared -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )

    # define loss function
    domain_adv = DomainAdversarialLoss(domain_discri, sigmoid=False)

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)
        domain_adv = DDP(domain_adv, delay_allreduce=True)

    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)

    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` only exists on rank 0 (see above); this branch
        # presumably assumes single-process testing -- confirm.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return

    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle each sampler with an epoch-dependent seed.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer,
              epoch, writer, args)

        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        # NOTE(review): prec1 here is the metric of the *last* entry in args.test_list
        # (default "test"); confirm model selection is meant to use it rather than "val".
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))


def train(train_labeled_loader, train_unlabeled_loader, model, criterion, domain_adv, optimizer,
          epoch, writer, args):
    """Train one epoch of DANN.

    Each step concatenates a labeled source batch with an unlabeled target batch,
    forwards them in a single joint pass, and minimizes source cross-entropy plus
    ``args.trade_off`` times the domain-adversarial loss.  Metrics are (optionally)
    reduced across processes and logged every ``args.print_freq`` steps on rank 0.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    domain_accs = AverageMeter('Domain Acc', ':3.1f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()
    end = time.time()
    # NOTE(review): the unlabeled loader is cycle()d below, yet the iteration
    # count is the min of both loader lengths -- if the unlabeled loader is
    # shorter, part of the labeled data is skipped each epoch; confirm intended.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = domain_adv(feature_s, feature_t)
        loss = loss_s + loss_trans * args.trade_off

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
                # DDP wraps the loss module, so reach through .module here.
                domain_acc = domain_adv.module.domain_discriminator_accuracy
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data
                domain_acc = domain_adv.domain_discriminator_accuracy

            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            domain_accs.update(to_python_float(domain_acc), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(),
                                                           args.class_names, metadata_s,
                                                           train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Domain Acc {domain_acc.val:.10f} ({domain_acc.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans,
                          domain_acc=domain_accs, top1=top1))


if __name__ == '__main__':
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__") and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='DANN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N',
                        help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int, help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N',
                        help='mini-batch size per process for source'
                             ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N',
                        help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument("--log", type=str, default='dann',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/dann.sh
================================================
CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/dann/fmow/lr_0_1_aa_v0_densenet121
CUDA_VISIBLE_DEVICES=0 python dann.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \
  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \
  --log logs/dann/iwildcam/lr_1_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/erm.py
================================================
"""
Adapted from https://github.com/NVIDIA/apex/tree/master/examples
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

import utils
from tllib.modules.classifier import Classifier
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy


def main(args):
    """Entry point: ERM (source-only) training on a WILDS classification dataset.

    Sets up logging/seeding, optional NCCL distributed training, the data
    pipeline and classifier, apex AMP and (optionally) apex
    DistributedDataParallel, then runs the train/validate loop, keeping the
    best checkpoint.  In 'test' phase it only evaluates the best checkpoint.

    Args:
        args: parsed command-line arguments (see the argparse setup below).
    """
    writer = None
    if args.local_rank == 0:
        # Only rank 0 owns the logger and the TensorBoard writer.
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Trade speed for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        # WORLD_SIZE is set by the torch.distributed launcher.
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size,
        scale=args.scale,
        ratio=args.ratio,
        hflip=args.hflip,
        vflip=args.vflip,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size,
        crop_pct=args.crop_pct,
        interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(model.get_parameters(), args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay, nesterov=True)

    # Initialize Amp.  Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )
    # Use cosine annealing learning rate strategy
    # NOTE(review): LambdaLR multiplies the optimizer's base lr by the lambda's
    # return value, and this lambda itself multiplies by args.lr, so the effective
    # lr is proportional to args.lr squared -- confirm this is intended.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize.  If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler)
    # NOTE(review): the unlabeled loader is built but never passed to train() below --
    # presumably kept for interface parity with the adaptation scripts; confirm.
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler)

    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` only exists on rank 0 (see above); this branch
        # presumably assumes single-process testing -- confirm.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return

    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle each sampler with an epoch-dependent seed.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)

        # evaluate on validation set
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        # NOTE(review): prec1 here is the metric of the *last* entry in args.test_list
        # (default "test"); confirm model selection is meant to use it rather than "val".
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))


def train(train_loader, model, criterion, optimizer, epoch, writer, args):
    """Train one epoch of plain supervised ERM on labeled source data.

    Minimizes (label-smoothed) cross-entropy with AMP-scaled SGD.  Metrics are
    (optionally) reduced across processes and logged every ``args.print_freq``
    steps on rank 0.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target, metadata) in enumerate(train_loader):
        # compute output
        output, _ = model(input.cuda())
        loss = criterion(output, target.cuda())

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output.data, target.cuda(), topk=(1,))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss = loss.data

            # to_python_float incurs a host<->device sync
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            global_step = epoch * len(train_loader) + i

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss", to_python_float(reduced_loss), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input.cpu(), target, output.cpu(),
                                                           args.class_names, metadata,
                                                           train_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss=losses, top1=top1))


if __name__ == '__main__':
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__") and callable(models.__dict__[name]))
    parser = argparse.ArgumentParser(description='Src Only')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: None => model default)')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N',
                        help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N',
                        help='mini-batch size per process for source'
                             ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N',
                        help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    parser.add_argument('--channels-last', type=bool, default=False)
    parser.add_argument("--log", type=str, default='src_only',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
================================================
FILE: examples/domain_adaptation/wilds_image_classification/erm.sh
================================================
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \
  --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/erm/fmow/lr_0_1_aa_v0_densenet121
CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 1 --opt-level O1 \
  --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 -p 500 --metric "F1-macro_all" \
  --log logs/erm/iwildcam/lr_1_deterministic
================================================
FILE: examples/domain_adaptation/wilds_image_classification/fixmatch.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import argparse
import os
import shutil
import time
import pprint
import math
from itertools import cycle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torchvision.models as models
from torch.utils.tensorboard import SummaryWriter
from timm.loss.cross_entropy import LabelSmoothingCrossEntropy
import wilds

try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

import utils
from tllib.modules.classifier import Classifier
from tllib.vision.transforms import MultipleApply
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter
import AverageMeter  # NOTE(review): extraction split "from tllib.utils.meter import AverageMeter" across chunk boundary
from tllib.utils.metric import accuracy


class ImageClassifier(Classifier):
    """Image classifier with a Linear -> BatchNorm1d -> ReLU bottleneck.

    The bottleneck projects pooled backbone features down to ``bottleneck_dim``
    before the classification head created by the tllib ``Classifier`` base.
    Unlike the base class, :meth:`forward` returns only logits (no features),
    because FixMatch training here needs predictions alone.
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=512, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """Return class logits of shape (batch, num_classes) for input images ``x``."""
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions


def main(args):
    """Entry point: set up (optionally distributed) AMP training and run FixMatch.

    Flow: logging (rank 0 only) -> determinism flags -> distributed init ->
    transforms/datasets -> model -> LR scaling -> amp.initialize -> scheduler ->
    DDP wrap -> loaders -> either test-only evaluation or the epoch loop.
    """
    writer = None
    # Only rank 0 creates the logger / TensorBoard writer and prints config.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Deterministic mode trades cudnn autotuning for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # torch.distributed.launch / torchrun exports WORLD_SIZE; >1 means multi-process.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code.
    # FixMatch: pseudo-labels come from a weakly augmented view, the consistency
    # loss is applied to a strongly augmented view of the same image.
    weak_transform = utils.get_train_transform(
        img_size=args.img_size, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip,
        color_jitter=None, auto_augment=None, interpolation=args.interpolation,
    )
    strong_transform = utils.get_train_transform(
        img_size=args.img_size, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip,
        color_jitter=args.color_jitter, auto_augment=args.aa, interpolation=args.interpolation,
    )
    train_source_transform = strong_transform
    # Each unlabeled sample yields a (weak, strong) pair of views.
    train_target_transform = MultipleApply([weak_transform, strong_transform])
    val_transform = utils.get_val_transform(
        img_size=args.img_size, crop_pct=args.crop_pct, interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_source_transform: ", train_source_transform)
        print('train_target_transform: ', train_target_transform)
        print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_source_transform, val_transform, verbose=args.local_rank == 0,
                          transform_train_target=train_target_transform)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = ImageClassifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                            pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size (linear scaling rule).
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr,
        momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    # Use cosine annealing learning rate strategy (stepped once per epoch, see loop below).
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)

    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only bound when local_rank == 0; running the
        # test phase with multiple processes would raise NameError on other ranks
        # — confirm test phase is single-process only.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return

    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle each sampler for the new epoch.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args)

        # evaluate on validation set
        # NOTE(review): prec1 below ends up holding the metric of the LAST entry
        # in test_list — "best" checkpoint selection tracks that split.
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))


def train(train_labeled_loader, train_unlabeled_loader, model, criterion, optimizer, epoch, writer, args):
    """Train for one epoch with the FixMatch objective.

    Per iteration: pseudo-label the weakly augmented unlabeled batch under
    ``no_grad``, keep only predictions above ``args.threshold`` (the mask),
    then compute supervised CE on the source batch plus masked CE between the
    strong-view predictions and the pseudo-labels.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_self_training = AverageMeter('Loss (self training)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()
    end = time.time()

    # One epoch = the shorter of the two loaders; cycle() guards the zip even
    # when the unlabeled loader is the shorter one.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), ((input_t, input_t_strong), metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        with torch.no_grad():
            # Pseudo-labels from the weak view; no gradients flow through them.
            output_t = model(input_t.cuda())
            confidence, pseudo_labels = F.softmax(output_t, dim=1).max(dim=1)
            # Only confident predictions contribute to the self-training loss.
            mask = (confidence > args.threshold).float()
        # Single forward pass over the concatenated source + strong-view batch.
        input = torch.cat([input_s.cuda(), input_t_strong.cuda()], dim=0)
        output = model(input)
        output_s, output_t_strong = output.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_self_training = args.trade_off * \
            (F.cross_entropy(output_t_strong, pseudo_labels, reduction='none') * mask).mean()
        loss = loss_s + loss_self_training

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_self_training = utils.reduce_tensor(loss_self_training.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_self_training = loss_self_training.data

            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_self_training.update(to_python_float(reduced_loss_self_training), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (self training)", to_python_float(reduced_loss_self_training),
                                  global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(),
                                                           args.class_names, metadata_s,
                                                           train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (self training) {loss_self_training.val:.10f} ({loss_self_training.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s,
                          loss_self_training=losses_self_training, top1=top1))
def str2bool(value):
    """Parse a boolean command-line value.

    Fixes the classic argparse pitfall where ``type=bool`` is applied to the
    raw string, so ``--channels-last False`` evaluated to True (any non-empty
    string is truthy). This converter keeps the value-taking CLI syntax
    backward compatible while parsing the string correctly.

    Accepts 'true'/'false', 'yes'/'no', 't'/'f', 'y'/'n', '1'/'0' in any case;
    already-boolean values pass through unchanged.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')


if __name__ == '__main__':
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='FixMatch')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: (224, 224))')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N',
                        help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int, help='Dimension of bottleneck')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    parser.add_argument('--threshold', default=0.7, type=float, help='confidence threshold')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N',
                        help='mini-batch size per process for source'
                             ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N',
                        help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # BUGFIX: was type=bool, which turned any non-empty string (including
    # "False") into True; str2bool parses the value correctly.
    parser.add_argument('--channels-last', type=str2bool, default=False)
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")

    args = parser.parse_args()
    main(args)
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_image_classification/fixmatch.sh ================================================ CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/fixmatch/fmow/lr_0_1_aa_v0_densenet121 CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" \ --lr 0.3 --opt-level O1 --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 12 -b 24 24 -p 500 \ --metric "F1-macro_all" --log logs/fixmatch/iwildcam/lr_0_3_deterministic ================================================ FILE: examples/domain_adaptation/wilds_image_classification/jan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import argparse import os import shutil import time import pprint import math from itertools import cycle import torch import torch.nn as nn import torch.nn.functional as F import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler import torchvision.models as models from torch.utils.tensorboard import SummaryWriter from timm.loss.cross_entropy import LabelSmoothingCrossEntropy import wilds try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp, optimizers from apex.multi_tensor_apply import multi_tensor_applier except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from tllib.alignment.jan import JointMultipleKernelMaximumMeanDiscrepancy, ImageClassifier as Classifier from tllib.modules.kernels 
import GaussianKernel  # NOTE(review): extraction split "from tllib.modules.kernels import GaussianKernel" across chunk boundary
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy


def main(args):
    """Entry point: set up (optionally distributed) AMP training and run JAN.

    Flow: logging (rank 0 only) -> determinism flags -> distributed init ->
    transforms/datasets -> model -> LR scaling -> amp.initialize -> scheduler ->
    DDP wrap -> loaders -> JMMD loss -> either test-only evaluation or the
    epoch loop.
    """
    writer = None
    # Only rank 0 creates the logger / TensorBoard writer and prints config.
    if args.local_rank == 0:
        logger = CompleteLogger(args.log, args.phase)
        if args.phase == 'train':
            writer = SummaryWriter(args.log)
        pprint.pprint(args)
        print("opt_level = {}".format(args.opt_level))
        print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
        print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))
        print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

    cudnn.benchmark = True
    best_prec1 = 0
    if args.deterministic:
        # Deterministic mode trades cudnn autotuning for reproducibility.
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # torch.distributed.launch / torchrun exports WORLD_SIZE; >1 means multi-process.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
    args.gpu = 0
    args.world_size = 1
    if args.distributed:
        args.gpu = args.local_rank
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.world_size = torch.distributed.get_world_size()
    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    if args.channels_last:
        memory_format = torch.channels_last
    else:
        memory_format = torch.contiguous_format

    # Data loading code
    train_transform = utils.get_train_transform(
        img_size=args.img_size, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip,
        color_jitter=args.color_jitter, auto_augment=args.aa, interpolation=args.interpolation,
    )
    val_transform = utils.get_val_transform(
        img_size=args.img_size, crop_pct=args.crop_pct, interpolation=args.interpolation,
    )
    if args.local_rank == 0:
        print("train_transform: ", train_transform)
        print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, verbose=args.local_rank == 0)

    # create model
    if args.local_rank == 0:
        if not args.scratch:
            print("=> using pre-trained model '{}'".format(args.arch))
        else:
            print("=> creating model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim,
                       pool_layer=pool_layer, finetune=not args.scratch)
    if args.sync_bn:
        import apex
        if args.local_rank == 0:
            print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)
    model = model.cuda().to(memory_format=memory_format)

    # Scale learning rate based on global batch size (linear scaling rule).
    args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256.
    optimizer = torch.optim.SGD(
        model.get_parameters(), args.lr,
        momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)

    # Initialize Amp. Amp accepts either values or strings for the optional override arguments,
    # for convenient interoperation with argparse.
    model, optimizer = amp.initialize(model, optimizer,
                                      opt_level=args.opt_level,
                                      keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                      loss_scale=args.loss_scale
                                      )

    # Use cosine annealing learning rate strategy (stepped once per epoch, see loop below).
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr)
    )

    # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
    # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called
    # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter
    # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks.
    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion)
    if args.smoothing:
        criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda()
    else:
        criterion = nn.CrossEntropyLoss().cuda()

    # Data loading code
    train_labeled_sampler = None
    train_unlabeled_sampler = None
    if args.distributed:
        train_labeled_sampler = DistributedSampler(train_labeled_dataset)
        train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset)
    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True)
    train_unlabeled_loader = DataLoader(
        train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True)

    if args.phase == 'test':
        # resume from the latest checkpoint
        # NOTE(review): `logger` is only bound when local_rank == 0; running the
        # test phase with multiple processes would raise NameError on other ranks
        # — confirm test phase is single-process only.
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            utils.validate(d, model, -1, writer, args)
        return

    # define loss function
    # JMMD aligns the joint distribution of (features, predictions) across
    # domains with two kernel families — Gaussians over features and a single
    # Gaussian over the softmax outputs.
    jmmd_loss = JointMultipleKernelMaximumMeanDiscrepancy(
        kernels=(
            [GaussianKernel(alpha=2 ** k) for k in range(-3, 2)],
            (GaussianKernel(sigma=0.92, track_running_stats=False),)
        ),
        linear=args.linear
    )

    for epoch in range(args.epochs):
        if args.distributed:
            # Reshuffle each sampler for the new epoch.
            train_labeled_sampler.set_epoch(epoch)
            train_unlabeled_sampler.set_epoch(epoch)
        lr_scheduler.step(epoch)
        if args.local_rank == 0:
            print(lr_scheduler.get_last_lr())
            writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer,
              epoch, writer, args)

        # evaluate on validation set
        # NOTE(review): prec1 below ends up holding the metric of the LAST entry
        # in test_list — "best" checkpoint selection tracks that split.
        for n, d in zip(args.test_list, test_datasets):
            if args.local_rank == 0:
                print(n)
            prec1 = utils.validate(d, model, epoch, writer, args)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
            if is_best:
                shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))


def train(train_labeled_loader, train_unlabeled_loader, model, criterion, jmmd_loss, optimizer, epoch, writer, args):
    """Train for one epoch with supervised CE plus the JMMD transfer loss.

    Per iteration: one forward pass over the concatenated source + target
    batch, then CE on the source half and JMMD between the (feature, softmax)
    pairs of the two halves, weighted by ``args.trade_off``.
    """
    batch_time = AverageMeter('Time', ':3.1f')
    losses_s = AverageMeter('Loss (s)', ':3.2f')
    losses_trans = AverageMeter('Loss (transfer)', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()
    end = time.time()

    # One epoch = the shorter of the two loaders; cycle() guards the zip even
    # when the unlabeled loader is the shorter one.
    num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader))
    for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \
            zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)):
        # compute output
        n_s, n_t = len(input_s), len(input_t)
        input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0)
        # The tllib JAN classifier returns (logits, bottleneck features).
        output, feature = model(input)
        output_s, output_t = output.split([n_s, n_t], dim=0)
        feature_s, feature_t = feature.split([n_s, n_t], dim=0)
        loss_s = criterion(output_s, target_s.cuda())
        loss_trans = jmmd_loss(
            (feature_s, F.softmax(output_s, dim=1)),
            (feature_t, F.softmax(output_t, dim=1))
        )
        loss = loss_s + loss_trans * args.trade_off

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,))

            # Average loss and accuracy across processes for logging
            if args.distributed:
                reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size)
                reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size)
                prec1 = utils.reduce_tensor(prec1, args.world_size)
            else:
                reduced_loss_s = loss_s.data
                reduced_loss_trans = loss_trans.data

            # to_python_float incurs a host<->device sync
            losses_s.update(to_python_float(reduced_loss_s), input_s.size(0))
            losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0))
            top1.update(to_python_float(prec1), input_s.size(0))
            global_step = epoch * num_iterations + i

            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            if args.local_rank == 0:
                writer.add_scalar('train/top1', to_python_float(prec1), global_step)
                writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step)
                writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step)
                writer.add_figure('train/predictions vs. actuals',
                                  utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(),
                                                           args.class_names, metadata_s,
                                                           train_labeled_loader.dataset.metadata_map),
                                  global_step=global_step)
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t'
                      'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          epoch, i, len(train_labeled_loader),
                          args.world_size * args.batch_size[0] / batch_time.val,
                          args.world_size * args.batch_size[0] / batch_time.avg,
                          batch_time=batch_time, loss_s=losses_s,
                          loss_trans=losses_trans, top1=top1))
def str2bool(value):
    """Parse a boolean command-line value.

    Fixes the classic argparse pitfall where ``type=bool`` is applied to the
    raw string, so ``--channels-last False`` evaluated to True (any non-empty
    string is truthy). This converter keeps the value-taking CLI syntax
    backward compatible while parsing the string correctly.

    Accepts 'true'/'false', 'yes'/'no', 't'/'f', 'y'/'n', '1'/'0' in any case;
    already-boolean values pass through unchanged.
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() in ('true', 't', 'yes', 'y', '1')


if __name__ == '__main__':
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    parser = argparse.ArgumentParser(description='JAN')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)')
    parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default="acc_worst_region")
    parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+',
                        help='Image patch size (default: (224, 224))')
    parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N',
                        help='Input image center crop percent (for validation only)')
    parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME',
                        help='Image resize interpolation type (overrides model)')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--hflip', type=float, default=0.5,
                        help='Horizontal flip training aug probability')
    parser.add_argument('--vflip', type=float, default=0.,
                        help='Vertical flip training aug probability')
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: None)')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--bottleneck-dim', default=512, type=int, help='Dimension of bottleneck')
    parser.add_argument('--linear', default=False, action='store_true',
                        help='whether use the linear version')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR',
                        help='Initial learning rate. Will be scaled by /256: '
                             'args.lr = args.lr*float(args.batch_size*args.world_size)/256.')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W',
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-6)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N',
                        help='mini-batch size per process for source'
                             ' and target domain (default: (64, 64))')
    parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N',
                        help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
    parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.')
    parser.add_argument('--opt-level', type=str)
    parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
    parser.add_argument('--loss-scale', type=str, default=None)
    # BUGFIX: was type=bool, which turned any non-empty string (including
    # "False") into True; str2bool parses the value correctly.
    parser.add_argument('--channels-last', type=str2bool, default=False)
    parser.add_argument("--log", type=str, default='jan',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")

    args = parser.parse_args()
    main(args)
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_image_classification/jan.sh ================================================ CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/jan/fmow/lr_0_1_aa_v0_densenet121 CUDA_VISIBLE_DEVICES=0 python jan.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \ --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ --log logs/jan/iwildcam/lr_0_3_deterministic ================================================ FILE: examples/domain_adaptation/wilds_image_classification/mdd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import argparse import os import shutil import time import pprint import math from itertools import cycle import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim import torch.utils.data from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler import torchvision.models as models from torch.utils.tensorboard import SummaryWriter from timm.loss.cross_entropy import LabelSmoothingCrossEntropy import wilds try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp, optimizers from apex.multi_tensor_apply import multi_tensor_applier except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from tllib.alignment.mdd import ClassificationMarginDisparityDiscrepancy \ as MarginDisparityDiscrepancy, ImageClassifier as Classifier from tllib.utils.logger import 
CompleteLogger from tllib.utils.meter import AverageMeter from tllib.utils.metric import accuracy def main(args): writer = None if args.local_rank == 0: logger = CompleteLogger(args.log, args.phase) if args.phase == 'train': writer = SummaryWriter(args.log) pprint.pprint(args) print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) cudnn.benchmark = True best_prec1 = 0 if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.seed) torch.set_printoptions(precision=10) args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.gpu = 0 args.world_size = 1 if args.distributed: args.gpu = args.local_rank torch.cuda.set_device(args.gpu) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." 
if args.channels_last: memory_format = torch.channels_last else: memory_format = torch.contiguous_format # Data loading code train_transform = utils.get_train_transform( img_size=args.img_size, scale=args.scale, ratio=args.ratio, hflip=args.hflip, vflip=args.vflip, color_jitter=args.color_jitter, auto_augment=args.aa, interpolation=args.interpolation, ) val_transform = utils.get_val_transform( img_size=args.img_size, crop_pct=args.crop_pct, interpolation=args.interpolation, ) if args.local_rank == 0: print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \ utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list, train_transform, val_transform, verbose=args.local_rank == 0) # create model if args.local_rank == 0: if not args.scratch: print("=> using pre-trained model '{}'".format(args.arch)) else: print("=> creating model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrain=not args.scratch) pool_layer = nn.Identity() if args.no_pool else None model = Classifier(backbone, args.num_classes, bottleneck_dim=args.bottleneck_dim, width=args.bottleneck_dim, pool_layer=pool_layer, finetune=not args.scratch) mdd = MarginDisparityDiscrepancy(args.margin) if args.sync_bn: import apex if args.local_rank == 0: print("using apex synced BN") model = apex.parallel.convert_syncbn_model(model) model = model.cuda().to(memory_format=memory_format) # Scale learning rate based on global batch size args.lr = args.lr * float(args.batch_size[0] * args.world_size) / 256. optimizer = torch.optim.SGD( model.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) # Initialize Amp. Amp accepts either values or strings for the optional override arguments, # for convenient interoperation with argparse. 
# -- mdd.py main() (continued): apex mixed-precision init, cosine LR schedule, DDP wrapping,
#    loss criterion and the source/target data loaders --
model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale ) # Use cosine annealing learning rate strategy lr_scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lambda x: max((math.cos(float(x) / args.epochs * math.pi) * 0.5 + 0.5) * args.lr, args.min_lr) ) # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. if args.distributed: # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. # model = DDP(model) # delay_allreduce delays all communication to the end of the backward pass. model = DDP(model, delay_allreduce=True) # define loss function (criterion) if args.smoothing: criterion = LabelSmoothingCrossEntropy(args.smoothing).cuda() else: criterion = nn.CrossEntropyLoss().cuda() # Data loading code train_labeled_sampler = None train_unlabeled_sampler = None if args.distributed: train_labeled_sampler = DistributedSampler(train_labeled_dataset) train_unlabeled_sampler = DistributedSampler(train_unlabeled_dataset) train_labeled_loader = DataLoader( train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, drop_last=True) train_unlabeled_loader = DataLoader( train_unlabeled_dataset, batch_size=args.batch_size[1], shuffle=(train_unlabeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_unlabeled_sampler, drop_last=True) if args.phase == 'test': # resume from the latest checkpoint checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
# -- mdd.py main() (continued): test-phase shortcut, then the train/validate epoch loop with
#    best-checkpoint bookkeeping on rank 0; mdd.py train() is defined on this same dump line.
#    NOTE(review): lr_scheduler.step(epoch) with an explicit epoch is the deprecated calling
#    convention for LambdaLR -- confirm against the pinned torch version.
model.load_state_dict(checkpoint) for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) utils.validate(d, model, -1, writer, args) return for epoch in range(args.epochs): if args.distributed: train_labeled_sampler.set_epoch(epoch) train_unlabeled_sampler.set_epoch(epoch) lr_scheduler.step(epoch) if args.local_rank == 0: print(lr_scheduler.get_last_lr()) writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch) # train for one epoch train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd, optimizer, epoch, writer, args) # evaluate on validation set for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) prec1 = utils.validate(d, model, epoch, writer, args) # remember best prec@1 and save checkpoint if args.local_rank == 0: is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if is_best: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) def train(train_labeled_loader, train_unlabeled_loader, model, criterion, mdd, optimizer, epoch, writer, args): batch_time = AverageMeter('Time', ':3.1f') losses_s = AverageMeter('Loss (s)', ':3.2f') losses_trans = AverageMeter('Loss (transfer)', ':3.2f') top1 = AverageMeter('Top 1', ':3.1f') # switch to train mode model.train() end = time.time() num_iterations = min(len(train_labeled_loader), len(train_unlabeled_loader)) for i, (input_s, target_s, metadata_s), (input_t, metadata_t) in \ zip(range(num_iterations), train_labeled_loader, cycle(train_unlabeled_loader)): # compute output n_s, n_t = len(input_s), len(input_t) input = torch.cat([input_s.cuda(), input_t.cuda()], dim=0) output, output_adv = model(input) output_s, output_t = output.split([n_s, n_t], dim=0) output_adv_s, output_adv_t = output_adv.split([n_s, n_t], dim=0) loss_s = criterion(output_s, target_s.cuda()) loss_trans = -mdd(output_s, output_adv_s, output_t, output_adv_t)
# -- mdd.py train() continues on the next dump line: total objective, amp-scaled backward,
#    optimizer step and rank-0 logging --
loss = loss_s + loss_trans * args.trade_off # compute gradient and do SGD step optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() if args.distributed: model.module.step() else: model.step() if i % args.print_freq == 0: # Every print_freq iterations, check the loss, accuracy, and speed. # For best performance, it doesn't make sense to print these metrics every # iteration, since they incur an allreduce and some host<->device syncs. # Measure accuracy prec1, = accuracy(output_s.data, target_s.cuda(), topk=(1,)) # Average loss and accuracy across processes for logging if args.distributed: reduced_loss_s = utils.reduce_tensor(loss_s.data, args.world_size) reduced_loss_trans = utils.reduce_tensor(loss_trans.data, args.world_size) prec1 = utils.reduce_tensor(prec1, args.world_size) else: reduced_loss_s = loss_s.data reduced_loss_trans = loss_trans.data # to_python_float incurs a host<->device sync losses_s.update(to_python_float(reduced_loss_s), input_s.size(0)) losses_trans.update(to_python_float(reduced_loss_trans), input_s.size(0)) top1.update(to_python_float(prec1), input_s.size(0)) global_step = epoch * num_iterations + i torch.cuda.synchronize() batch_time.update((time.time() - end) / args.print_freq) end = time.time() if args.local_rank == 0: writer.add_scalar('train/top1', to_python_float(prec1), global_step) writer.add_scalar("train/loss (s)", to_python_float(reduced_loss_s), global_step) writer.add_scalar("train/loss (trans)", to_python_float(reduced_loss_trans), global_step) writer.add_figure('train/predictions vs. 
actuals', utils.plot_classes_preds(input_s.cpu(), target_s, output_s.cpu(), args.class_names, metadata_s, train_labeled_loader.dataset.metadata_map), global_step=global_step) print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Speed {3:.3f} ({4:.3f})\t' 'Loss (s) {loss_s.val:.10f} ({loss_s.avg:.4f})\t' 'Loss (trans) {loss_trans.val:.10f} ({loss_trans.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format( epoch, i, len(train_labeled_loader), args.world_size * args.batch_size[0] / batch_time.val, args.world_size * args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss_s=losses_s, loss_trans=losses_trans, top1=top1)) if __name__ == '__main__': model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name])) parser = argparse.ArgumentParser(description='MDD') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='fmow', choices=wilds.supported_datasets, help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: fmow)') parser.add_argument('--unlabeled-list', nargs='+', default=["test_unlabeled", ]) parser.add_argument('--test-list', nargs='+', default=["val", "test"]) parser.add_argument('--metric', default="acc_worst_region") parser.add_argument('--img-size', type=int, default=(224, 224), metavar='N', nargs='+', help='Image patch size (default: None => model default)') parser.add_argument('--crop-pct', default=utils.DEFAULT_CROP_PCT, type=float, metavar='N', help='Input image center crop percent (for validation only)') parser.add_argument('--interpolation', default='bicubic', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT', help='Random resize scale (default: 0.5 1.0)') parser.add_argument('--ratio', type=float, 
nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') parser.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') parser.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') parser.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)') # model parameters parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet50', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--scratch', action='store_true', help='whether train from scratch.') parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') parser.add_argument('--bottleneck-dim', default=2048, type=int) parser.add_argument('--margin', type=float, default=4., help="margin gamma") parser.add_argument('--trade-off', default=1., type=float, help='the trade-off hyper-parameter for transfer loss') # Learning rate schedule parameters parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='Initial learning rate. 
Will be scaled by /256: ' 'args.lr = args.lr*float(args.batch_size*args.world_size)/256.') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-6)') # training parameters parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N', help='mini-batch size per process for source and target domain (default: (64, 64))') parser.add_argument('--print-freq', '-p', default=200, type=int, metavar='N', help='print frequency (default: 200)') parser.add_argument('--deterministic', action='store_true') parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ') parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int) parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.') parser.add_argument('--opt-level', type=str) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--channels-last', type=bool, default=False) parser.add_argument("--log", type=str, default='mdd', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_image_classification/mdd.sh ================================================ CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "fmow" --aa "v0" --arch "densenet121" \ --lr 0.1 --opt-level O1 --deterministic --vflip 0.5 --log logs/mdd/fmow/lr_0_1_aa_v0_densenet121 CUDA_VISIBLE_DEVICES=0 python mdd.py data/wilds -d "iwildcam" --aa "v0" --unlabeled-list "extra_unlabeled" --lr 0.3 --opt-level O1 \ --deterministic --img-size 448 448 --crop-pct 1.0 --scale 1.0 1.0 --epochs 18 -b 24 24 --trade-off 0.3 -p 500 --metric "F1-macro_all" \ --log logs/mdd/iwildcam/lr_0_3_deterministic ================================================ FILE: examples/domain_adaptation/wilds_image_classification/requirements.txt ================================================ wilds timm tensorflow tensorboard ================================================ FILE: examples/domain_adaptation/wilds_image_classification/utils.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import time import math import matplotlib.pyplot as plt import numpy as np import sys import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from torch.utils.data.distributed import DistributedSampler from torch.utils.data import DataLoader, ConcatDataset from torchvision import transforms from PIL import Image import wilds from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform import timm sys.path.append('../../..') from tllib.vision.transforms import Denormalize from tllib.utils.meter import AverageMeter, ProgressMeter def get_model_names(): return timm.list_models() def 
get_model(model_name, pretrain=True): # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=pretrain) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() return backbone def get_dataset(dataset_name, root, unlabeled_list=("test_unlabeled",), test_list=("test",), transform_train=None, transform_test=None, verbose=True, transform_train_target=None): if transform_train_target is None: transform_train_target = transform_train labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True) num_classes = labeled_dataset.n_classes train_labeled_dataset = labeled_dataset.get_subset("train", transform=transform_train) train_unlabeled_datasets = [ unlabeled_dataset.get_subset(u, transform=transform_train_target) for u in unlabeled_list ] train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets) test_datasets = [ labeled_dataset.get_subset(t, transform=transform_test) for t in test_list ] if dataset_name == "fmow": from wilds.datasets.fmow_dataset import categories class_names = categories else: class_names = list(range(num_classes)) if verbose: print("Datasets") for n, d in zip(["train"] + unlabeled_list + test_list, [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets): print("\t{}:{}".format(n, len(d))) print("\t#classes:", num_classes) return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names def collate_list(vec): """ Adapted from https://github.com/p-lambda/wilds If vec is a list of Tensors, it concatenates them all along the first dimension. If vec is a list of lists, it joins these lists together, but does not attempt to recursively collate. This allows each element of the list to be, e.g., its own dict. 
def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds

    If vec is a list of Tensors, concatenate them along the first dimension.
    If vec is a list of lists, join the lists together without recursive
    collation, so each element may be, e.g., its own dict.
    If vec is a list of dicts (all sharing the same keys), return a single dict
    whose entry for each key is the recursive collation of that key's values.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    head = vec[0]
    if torch.is_tensor(head):
        return torch.cat(vec)
    if isinstance(head, list):
        return [item for sub in vec for item in sub]
    if isinstance(head, dict):
        return {key: collate_list([entry[key] for entry in vec]) for key in head}
    raise TypeError("Elements of the list to collate must be tensors or dicts.")


def get_train_transform(img_size, scale=None, ratio=None, hflip=0.5, vflip=0., color_jitter=0.4,
                        auto_augment=None, interpolation='bilinear'):
    """Training pipeline: random resized crop, optional flips, AutoAugment *or*
    color jitter, then tensor conversion and ImageNet normalization."""
    crop_scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
    crop_ratio = tuple(ratio or (3. / 4., 4. / 3.))  # default imagenet ratio range
    ops = [transforms.RandomResizedCrop(img_size, scale=crop_scale, ratio=crop_ratio,
                                        interpolation=_pil_interp(interpolation))]
    if hflip > 0.:
        ops.append(transforms.RandomHorizontalFlip(p=hflip))
    if vflip > 0.:
        ops.append(transforms.RandomVerticalFlip(p=vflip))
    if auto_augment:
        assert isinstance(auto_augment, str)
        shorter_side = min(img_size) if isinstance(img_size, (tuple, list)) else img_size
        aa_params = dict(
            translate_const=int(shorter_side * 0.45),
            img_mean=tuple([min(255, round(255 * x)) for x in IMAGENET_DEFAULT_MEAN]),
        )
        if interpolation and interpolation != 'random':
            aa_params['interpolation'] = _pil_interp(interpolation)
        if auto_augment.startswith('rand'):
            ops.append(rand_augment_transform(auto_augment, aa_params))
        elif auto_augment.startswith('augmix'):
            aa_params['translate_pct'] = 0.3
            ops.append(augment_and_mix_transform(auto_augment, aa_params))
        else:
            ops.append(auto_augment_transform(auto_augment, aa_params))
    elif color_jitter is not None:
        # color jitter is enabled when not using AA
        if isinstance(color_jitter, (list, tuple)):
            # 3-tuple/list for brightness/contrast/saturation, or 4 if also augmenting hue
            assert len(color_jitter) in (3, 4)
        else:
            # a scalar is duplicated for brightness, contrast and saturation; no hue
            color_jitter = (float(color_jitter),) * 3
        ops.append(transforms.ColorJitter(*color_jitter))
    ops += [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ]
    return transforms.Compose(ops)


def get_val_transform(img_size=224, crop_pct=None, interpolation='bilinear'):
    """Evaluation pipeline: resize (scaled up by 1/crop_pct), center crop,
    tensor conversion and ImageNet normalization."""
    crop_pct = crop_pct or DEFAULT_CROP_PCT
    if isinstance(img_size, (tuple, list)):
        assert len(img_size) == 2
        if img_size[-1] == img_size[-2]:
            # fall-back to older behaviour so Resize scales to shortest edge if target is square
            scale_size = int(math.floor(img_size[0] / crop_pct))
        else:
            scale_size = tuple([int(x / crop_pct) for x in img_size])
    else:
        scale_size = int(math.floor(img_size / crop_pct))
    return transforms.Compose([
        transforms.Resize(scale_size, _pil_interp(interpolation)),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])


def _pil_interp(method):
    """Map an interpolation name to its PIL resampling constant."""
    # default bilinear, do we want to allow nearest?
    return {
        'bicubic': Image.BICUBIC,
        'lanczos': Image.LANCZOS,
        'hamming': Image.HAMMING,
    }.get(method, Image.BILINEAR)
# "return Image.BILINEAR" -- tail of _pil_interp(), whose def sits at the end of the
# previous line of this extract.


def validate(val_dataset, model, epoch, writer, args):
    """Evaluate `model` on `val_dataset`.

    Runs a full pass over the (optionally distributed) validation loader, gathers
    predictions/labels/metadata, and on rank 0 logs a "predictions vs. actuals"
    figure plus the WILDS evaluation metrics to tensorboard.

    Returns:
        results[0][args.metric] on rank 0 (None on other ranks, where `writer`
        is None and evaluation is skipped).
    """
    val_sampler = None
    if args.distributed:
        val_sampler = DistributedSampler(val_dataset)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    all_y_true = []
    all_y_pred = []
    all_metadata = []
    # one sample per batch is kept for the debug figure
    sampled_inputs = []
    sampled_outputs = []
    sampled_targets = []
    sampled_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        all_y_true.append(target)
        all_y_pred.append(output.argmax(1))
        all_metadata.append(metadata)
        sampled_inputs.append(input[0:1])
        sampled_targets.append(target[0:1])
        sampled_outputs.append(output[0:1])
        sampled_metadata.append(metadata[0:1])
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if args.local_rank == 0 and i % args.print_freq == 0:
            progress.display(i)

    # NOTE(review): in the collapsed dump the indentation of the evaluation block is
    # ambiguous; it is kept under the rank-0 guard because `writer` is None on the
    # other ranks -- confirm against the original file.
    if args.local_rank == 0:
        writer.add_figure(
            'test/predictions vs. actuals',
            plot_classes_preds(
                collate_list(sampled_inputs), collate_list(sampled_targets),
                collate_list(sampled_outputs), args.class_names,
                collate_list(sampled_metadata), val_dataset.metadata_map,
                nrows=min(int(len(val_loader) / 4), 50)
            ),
            global_step=epoch
        )
        # evaluate
        results = val_dataset.eval(
            collate_list(all_y_pred), collate_list(all_y_true), collate_list(all_metadata)
        )
        print(results[1])
        for k, v in results[0].items():
            if v == 0 or "Other" in k:
                continue
            writer.add_scalar("test/{}".format(k), v, global_step=epoch)
        return results[0][args.metric]


def reduce_tensor(tensor, world_size):
    """Average `tensor` across all distributed processes."""
    rt = tensor.clone()
    # FIX: `dist.reduce_op` is a deprecated alias removed in recent PyTorch releases;
    # the canonical spelling is the `dist.ReduceOp` enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


def matplotlib_imshow(img):
    """helper function to show an image"""
    img = Denormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)(img)
    img = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(img)

# def plot_classes_preds(...) -- the def line and the first half of its docstring sit at
# the very end of this extract line; the complete function continues on the next line.
''' # convert output probabilities to predicted class _, preds_tensor = torch.max(outputs, 1) preds = np.squeeze(preds_tensor.numpy()) probs = [F.softmax(el, dim=0)[i].item() for i, el in zip(preds, outputs)] # plot the images in the batch, along with predicted and true labels fig = plt.figure(figsize=(12, nrows * 4)) domains = get_domain_names(metadata, metadata_map) for idx in np.arange(min(nrows * 4, len(images))): ax = fig.add_subplot(nrows, 4, idx + 1, xticks=[], yticks=[]) matplotlib_imshow(images[idx]) ax.set_title("{0}, {1:.1f}%\n(label: {2}\ndomain: {3})".format( class_names[preds[idx]], probs[idx] * 100.0, class_names[labels[idx]], domains[idx], ), color=("green" if preds[idx] == labels[idx].item() else "red")) return fig def get_domain_names(metadata, metadata_map): return get_domain_ids(metadata) def get_domain_ids(metadata): return [int(m[0]) for m in metadata] ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/README.md ================================================ # Unsupervised Domain Adaptation for WILDS (Molecule classification) ## Installation It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results. Then, you need to run ``` pip install -r requirements.txt ``` At last, you need to install torch_sparse following `https://github.com/rusty1s/pytorch_sparse`. ## Dataset Following datasets can be downloaded automatically: - [OGB-MolPCBA (WILDS)](https://wilds.stanford.edu/datasets/) ## Supported Methods TODO ## Usage The shell files give all the training scripts we use, e.g. 
``` CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \ --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic ``` ## Results ### Performance on WILDS-OGB-MolPCBA (GIN-virtual) | Methods | Val Avg Precision | Test Avg Precision | GPU Memory Usage(GB)| | --- | --- | --- | --- | | ERM | 29.0 | 28.0 | 17.8 | ### Visualization We use tensorboard to record the training process and visualize the outputs of the models. ``` tensorboard --logdir=logs ``` ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/erm.py ================================================ """ @author: Jiaxin Li @contact: thulijx@gmail.com """ import argparse import shutil import time import pprint import torch import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter import wilds import utils from tllib.utils.logger import CompleteLogger from tllib.utils.meter import AverageMeter def main(args): logger = CompleteLogger(args.log, args.phase) writer = SummaryWriter(args.log) pprint.pprint(args) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) cudnn.benchmark = True if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.seed) torch.set_printoptions(precision=10) # Data loading code # There are no well-developed data augmentation techniques for molecular graphs. 
train_transform = None val_transform = None print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_classes, args.class_names = \ utils.get_dataset('ogb-molpcba', args.data_dir, args.unlabeled_list, args.test_list, train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True) # create model print("=> creating model '{}'".format(args.arch)) model = utils.get_model(args.arch, args.num_classes) model = model.cuda().to() optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay ) # Data loading code train_labeled_sampler = None train_labeled_loader = DataLoader( train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler, collate_fn=train_labeled_dataset.collate) # define loss function (criterion) criterion = utils.reduced_bce_logit_loss if args.phase == 'test': # resume from the latest checkpoint checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) for n, d in zip(args.test_list, test_datasets): print(n) utils.validate(d, model, -1, writer, args) return # start training best_val_metric = 0 test_metric = 0 for epoch in range(args.epochs): # train for one epoch train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args) # evaluate on validation set for n, d in zip(args.test_list, test_datasets): print(n) if n == 'val': tmp_val_metric = utils.validate(d, model, epoch, writer, args) elif n == 'test': tmp_test_metric = utils.validate(d, model, epoch, writer, args) # remember best mse and save checkpoint is_best = tmp_val_metric > best_val_metric best_val_metric = max(tmp_val_metric, best_val_metric) torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if is_best: test_metric = 
tmp_test_metric shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) print("best val performance: {:.3f}".format(best_val_metric)) print("test performance: {:.3f}".format(test_metric)) logger.close() writer.close() def train(train_loader, model, criterion, optimizer, epoch, writer, args): batch_time = AverageMeter('Time', ':3.1f') losses = AverageMeter('Loss', ':3.2f') # switch to train mode model.train() end = time.time() for i, (input, target, metadata) in enumerate(train_loader): # compute output output = model(input.cuda()) loss = criterion(output, target.cuda()) # compute gradient and do optimizer step optimizer.zero_grad() loss.backward() optimizer.step() if i % args.print_freq == 0: # Every print_freq iterations, check the loss, accuracy, and speed. losses.update(loss, input.size(0)) global_step = epoch * len(train_loader) + i batch_time.update((time.time() - end) / args.print_freq) end = time.time() writer.add_scalar('train/loss', loss, global_step) print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Speed {3:.3f} ({4:.3f})\t' 'Loss {loss.val:.10f} ({loss.avg:.4f})\t'.format( epoch, i, len(train_loader), args.batch_size[0] / batch_time.val, args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss=losses)) if __name__ == '__main__': model_names = utils.get_model_names() parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='ogb-molpcba', choices=wilds.supported_datasets, help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: ogb-molpcba)') parser.add_argument('--unlabeled-list', nargs='+', default=[]) parser.add_argument('--test-list', nargs='+', default=['val', 'test']) parser.add_argument('--metric', default='ap', help='metric used to evaluate model performance. 
(default: average precision)') parser.add_argument('--use-unlabeled', action='store_true', help='Whether use unlabeled data for training or not.') # model parameters parser.add_argument('--arch', '-a', metavar='ARCH', default='gin_virtual', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: gin_virtual)') # Learning rate schedule parameters parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='Learning rate') parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='W', help='weight decay (default: 0.0)') # training parameters parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N', help='mini-batch size per process for source' ' and target domain (default: (64, 64))') parser.add_argument('--print-freq', '-p', default=50, type=int, metavar='N', help='print frequency (default: 50)') parser.add_argument('--deterministic', action='store_true') parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ') parser.add_argument('--log', type=str, default='src_only', help='Where to save logs, checkpoints and debugging images.') parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/erm.sh ================================================ # ogb-molpcba CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --lr 3e-2 -b 4096 4096 --epochs 200 \ --seed 0 --deterministic --log logs/erm/obg_lr_0_03_deterministic ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/gin.py ================================================ """ Adapted from "https://github.com/p-lambda/wilds" @author: Jiaxin Li @contact: thulijx@gmail.com """ import torch from torch_geometric.nn import MessagePassing from torch_geometric.nn import global_mean_pool, global_add_pool import torch.nn.functional as F from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder __all__ = ['gin_virtual'] class GINVirtual(torch.nn.Module): """ Graph Isomorphism Network augmented with virtual node for multi-task binary graph classification. Args: num_tasks (int): number of binary label tasks. 
default to 128 (number of tasks of ogbg-molpcba) num_layers (int): number of message passing layers of GNN emb_dim (int): dimensionality of hidden channels dropout (float): dropout ratio applied to hidden channels Inputs: - batched Pytorch Geometric graph object Outputs: - prediction (tensor): float torch tensor of shape (num_graphs, num_tasks) """ def __init__(self, num_tasks=128, num_layers=5, emb_dim=300, dropout=0.5): super(GINVirtual, self).__init__() self.num_layers = num_layers self.dropout = dropout self.emb_dim = emb_dim self.num_tasks = num_tasks if num_tasks is None: self.d_out = self.emb_dim else: self.d_out = self.num_tasks if self.num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") # GNN to generate node embeddings self.gnn_node = GINVirtualNode(num_layers, emb_dim, dropout=dropout) # Pooling function to generate whole-graph embeddings self.pool = global_mean_pool if num_tasks is None: self.graph_pred_linear = None else: self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks) def forward(self, batched_data): h_node = self.gnn_node(batched_data) h_graph = self.pool(h_node, batched_data.batch) if self.graph_pred_linear is None: return h_graph else: return self.graph_pred_linear(h_graph) class GINVirtualNode(torch.nn.Module): """ Helper function of Graph Isomorphism Network augmented with virtual node for multi-task binary graph classification This will generate node embeddings. Args: num_layers (int): number of message passing layers of GNN emb_dim (int): dimensionality of hidden channels dropout (float, optional): dropout ratio applied to hidden channels. 
Default: 0.5 Inputs: - batched Pytorch Geometric graph object Outputs: - node_embedding (tensor): float torch tensor of shape (num_nodes, emb_dim) """ def __init__(self, num_layers, emb_dim, dropout=0.5): super(GINVirtualNode, self).__init__() self.num_layers = num_layers self.dropout = dropout if self.num_layers < 2: raise ValueError("Number of GNN layers must be greater than 1.") self.atom_encoder = AtomEncoder(emb_dim) # set the initial virtual node embedding to 0. self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim) torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0) # List of GNNs self.convs = torch.nn.ModuleList() # batch norms applied to node embeddings self.batch_norms = torch.nn.ModuleList() # List of MLPs to transform virtual node at every layer self.mlp_virtualnode_list = torch.nn.ModuleList() for layer in range(num_layers): self.convs.append(GINConv(emb_dim)) self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim)) for layer in range(num_layers - 1): self.mlp_virtualnode_list.append( torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU())) def forward(self, batched_data): x, edge_index, edge_attr, batch = batched_data.x, batched_data.edge_index, batched_data.edge_attr, batched_data.batch # virtual node embeddings for graphs virtualnode_embedding = self.virtualnode_embedding( torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) h_list = [self.atom_encoder(x)] for layer in range(self.num_layers): # add message from virtual nodes to graph nodes h_list[layer] = h_list[layer] + virtualnode_embedding[batch] # Message passing among graph nodes h = self.convs[layer](h_list[layer], edge_index, edge_attr) h = self.batch_norms[layer](h) if layer == self.num_layers - 1: # remove relu for the last layer h = F.dropout(h, self.dropout, training=self.training) else: h = 
F.dropout(F.relu(h), self.dropout, training=self.training) h_list.append(h) # update the virtual nodes if layer < self.num_layers - 1: # add message from graph nodes to virtual nodes virtualnode_embedding_temp = global_add_pool(h_list[layer], batch) + virtualnode_embedding # transform virtual nodes using MLP virtualnode_embedding = F.dropout(self.mlp_virtualnode_list[layer](virtualnode_embedding_temp), self.dropout, training=self.training) node_embedding = h_list[-1] return node_embedding class GINConv(MessagePassing): """ Graph Isomorphism Network message passing. Args: emb_dim (int): node embedding dimensionality Inputs: - x (tensor): node embedding - edge_index (tensor): edge connectivity information - edge_attr (tensor): edge feature Outputs: - prediction (tensor): output node embedding """ def __init__(self, emb_dim): super(GINConv, self).__init__(aggr="add") self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim), torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim)) self.eps = torch.nn.Parameter(torch.Tensor([0])) self.bond_encoder = BondEncoder(emb_dim=emb_dim) def forward(self, x, edge_index, edge_attr): edge_embedding = self.bond_encoder(edge_attr) out = self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) return out def message(self, x_j, edge_attr): return F.relu(x_j + edge_attr) def update(self, aggr_out): return aggr_out def gin_virtual(num_tasks, dropout=0.5): model = GINVirtual(num_tasks=num_tasks, dropout=dropout) return model ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/requirements.txt ================================================ torch_geometric wilds tensorflow tensorboard ogb ================================================ FILE: examples/domain_adaptation/wilds_ogb_molpcba/utils.py ================================================ """ @author: Jiaxin Li @contact: thulijx@gmail.com """ import time 
import sys import torch import torch.nn as nn from torch.utils.data import DataLoader, ConcatDataset import wilds sys.path.append('../../..') import gin as models from tllib.utils.meter import AverageMeter, ProgressMeter def reduced_bce_logit_loss(y_pred, y_target): """ Every item of y_target has n elements which may be labeled by nan. Nan values should not be used while calculating loss. So extract elements which are not nan first, and then calculate loss. """ loss = nn.BCEWithLogitsLoss(reduction='none').cuda() is_labeled = ~torch.isnan(y_target) y_pred = y_pred[is_labeled].float() y_target = y_target[is_labeled].float() metrics = loss(y_pred, y_target) return metrics.mean() def get_dataset(dataset_name, root, unlabeled_list=('test_unlabeled',), test_list=('test',), transform_train=None, transform_test=None, use_unlabeled=True, verbose=True): labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train) if use_unlabeled: unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True) train_unlabeled_datasets = [ unlabeled_dataset.get_subset(u, transform=transform_train) for u in unlabeled_list ] train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets) else: unlabeled_list = [] train_unlabeled_datasets = [] train_unlabeled_dataset = None test_datasets = [ labeled_dataset.get_subset(t, transform=transform_test) for t in test_list ] if dataset_name == 'ogb-molpcba': num_classes = labeled_dataset.y_size else: num_classes = labeled_dataset.n_classes class_names = list(range(num_classes)) if verbose: print('Datasets') for n, d in zip(['train'] + unlabeled_list + test_list, [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets): print('\t{}:{}'.format(n, len(d))) print('\t#classes:', num_classes) return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_classes, class_names def 
get_model_names(): return sorted(name for name in models.__dict__ if name.islower() and not name.startswith('__') and callable(models.__dict__[name])) def get_model(arch, num_classes): if arch in models.__dict__: model = models.__dict__[arch](num_tasks=num_classes) else: raise ValueError('{} is not supported'.format(arch)) return model def collate_list(vec): """ Adapted from https://github.com/p-lambda/wilds If vec is a list of Tensors, it concatenates them all along the first dimension. If vec is a list of lists, it joins these lists together, but does not attempt to recursively collate. This allows each element of the list to be, e.g., its own dict. If vec is a list of dicts (with the same keys in each dict), it returns a single dict with the same keys. For each key, it recursively collates all entries in the list. """ if not isinstance(vec, list): raise TypeError("collate_list must take in a list") elem = vec[0] if torch.is_tensor(elem): return torch.cat(vec) elif isinstance(elem, list): return [obj for sublist in vec for obj in sublist] elif isinstance(elem, dict): return {k: collate_list([d[k] for d in vec]) for k in elem} else: raise TypeError("Elements of the list to collate must be tensors or dicts.") def validate(val_dataset, model, epoch, writer, args): val_sampler = None val_loader = DataLoader( val_dataset, batch_size=args.batch_size[0], shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler, collate_fn=val_dataset.collate) all_y_true = [] all_y_pred = [] all_metadata = [] batch_time = AverageMeter('Time', ':6.3f') progress = ProgressMeter( len(val_loader), [batch_time], prefix='Test: ') # switch to evaluate mode model.eval() end = time.time() for i, (input, target, metadata) in enumerate(val_loader): # compute output with torch.no_grad(): output = model(input.cuda()).cpu() all_y_true.append(target) all_y_pred.append(output) all_metadata.append(metadata) # measure elapsed time batch_time.update(time.time() - end) end = time.time() 
if args.local_rank == 0 and i % args.print_freq == 0: progress.display(i) # evaluate results = val_dataset.eval( collate_list(all_y_pred), collate_list(all_y_true), collate_list(all_metadata) ) print(results[1]) for k, v in results[0].items(): if v == 0 or "Other" in k: continue writer.add_scalar("test/{}".format(k), v, global_step=epoch) return results[0][args.metric] ================================================ FILE: examples/domain_adaptation/wilds_poverty/README.md ================================================ # Unsupervised Domain Adaptation for WILDS (Image Regression) ## Installation It's suggested to use **pytorch==1.10.1** in order to reproduce the benchmark results. You need to install apex following `https://github.com/NVIDIA/apex`. Then run ``` pip install -r requirements.txt ``` ## Dataset Following datasets can be downloaded automatically: - [PovertyMap (WILDS)](https://wilds.stanford.edu/datasets/) ## Supported Methods TODO ## Usage Our code is based on [https://github.com/NVIDIA/apex/edit/master/examples/imagenet](https://github.com/NVIDIA/apex/edit/master/examples/imagenet) . It implements Automatic Mixed Precision (Amp) training of popular model architectures, such as ResNet, AlexNet, and VGG, on the WILDS dataset. Command-line flags forwarded to `amp.initialize` are used to easily manipulate and switch between various pure and mixed precision "optimization levels" or `opt_level`s. For a detailed explanation of `opt_level`s, see the [updated API guide](https://nvidia.github.io/apex/amp.html). The shell files give all the training scripts we use, e.g. 
``` CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A ``` ## Results ### Performance on WILDS-PovertyMap (ResNet18-MultiSpectral) | Method | Val Pearson r | Test Pearson r | Val Worst-U/R Pearson r | Test Worst-U/R Pearson r | GPU Memory Usage(GB) | | --- | --- | --- | --- | --- | --- | | ERM | 0.80 | 0.80 | 0.54 | 0.50 | 3.5 | ### Distributed training We uses `apex.parallel.DistributedDataParallel` (DDP) for multiprocess training with one GPU per process. ``` CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 erm.py /data/wilds --arch 'resnet18_ms' \ --opt-level O1 --deterministic --log logs/erm/poverty --lr 1e-3 --wd 0.0 --epochs 200 --metric r_wg --split_scheme official -b 64 64 --fold A ``` ### Visualization We use tensorboard to record the training process and visualize the outputs of the models. ``` tensorboard --logdir=logs ``` ================================================ FILE: examples/domain_adaptation/wilds_poverty/erm.py ================================================ """ @author: Jiaxin Li @contact: thulijx@gmail.com """ import argparse import os import shutil import time import pprint import torch import torch.nn as nn import torch.backends.cudnn as cudnn import torch.nn.functional as F from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler from torch.utils.tensorboard import SummaryWriter try: from apex.parallel import DistributedDataParallel as DDP from apex.fp16_utils import * from apex import amp except ImportError: raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") import utils from utils import Regressor from tllib.utils.logger import CompleteLogger from tllib.utils.meter import AverageMeter def main(args): writer = None if args.local_rank == 0: logger = 
CompleteLogger(args.log, args.phase) if args.phase == 'train': writer = SummaryWriter(args.log) pprint.pprint(args) print("opt_level = {}".format(args.opt_level)) print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32)) print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale)) print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version())) cudnn.benchmark = True if args.deterministic: cudnn.benchmark = False cudnn.deterministic = True torch.manual_seed(args.seed) torch.set_printoptions(precision=10) args.distributed = False if 'WORLD_SIZE' in os.environ: args.distributed = int(os.environ['WORLD_SIZE']) > 1 args.gpu = 0 args.world_size = 1 if args.distributed: args.gpu = args.local_rank torch.cuda.set_device(args.gpu) torch.distributed.init_process_group(backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." if args.channels_last: memory_format = torch.channels_last else: memory_format = torch.contiguous_format # Data loading code # Images in povertyMap dataset have 8 channels and traditional data augmentation # methods have no effect on performance. 
train_transform = None val_transform = None if args.local_rank == 0: print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_labeled_dataset, train_unlabeled_dataset, test_datasets, args.num_channels = \ utils.get_dataset('poverty', args.data_dir, args.unlabeled_list, args.test_list, args.split_scheme, train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=args.local_rank == 0, fold=args.fold) # create model if args.local_rank == 0: print("=> creating model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.num_channels) pool_layer = nn.Identity() if args.no_pool else None model = Regressor(backbone, pool_layer=pool_layer, finetune=False) if args.sync_bn: import apex if args.local_rank == 0: print("using apex synced BN") model = apex.parallel.convert_syncbn_model(model) model = model.cuda().to(memory_format=memory_format) optimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.weight_decay) # Initialize Amp. Amp accepts either values or strings for the optional override arguments, # for convenient interoperation with argparse. model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level, keep_batchnorm_fp32=args.keep_batchnorm_fp32, loss_scale=args.loss_scale ) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=args.gamma, step_size=args.step_size) # For distributed training, wrap the model with apex.parallel.DistributedDataParallel. # This must be done AFTER the call to amp.initialize. If model = DDP(model) is called # before model, ... = amp.initialize(model, ...), the call to amp.initialize may alter # the types of model's parameters in a way that disrupts or destroys DDP's allreduce hooks. if args.distributed: # By default, apex.parallel.DistributedDataParallel overlaps communication with # computation in the backward pass. 
# model = DDP(model) # delay_allreduce delays all communication to the end of the backward pass. model = DDP(model, delay_allreduce=True) # Data loading code train_labeled_sampler = None if args.distributed: train_labeled_sampler = DistributedSampler(train_labeled_dataset) train_labeled_loader = DataLoader( train_labeled_dataset, batch_size=args.batch_size[0], shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_labeled_sampler) if args.phase == 'test': # resume from the latest checkpoint checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) utils.validate(d, model, -1, writer, args) return # start training best_val_metric = 0 test_metric = 0 for epoch in range(args.epochs): if args.distributed: train_labeled_sampler.set_epoch(epoch) lr_scheduler.step(epoch) if args.local_rank == 0: print(lr_scheduler.get_last_lr()) writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch) # train for one epoch train(train_labeled_loader, model, optimizer, epoch, writer, args) # evaluate on validation set for n, d in zip(args.test_list, test_datasets): if args.local_rank == 0: print(n) if n == 'val': tmp_val_metric = utils.validate(d, model, epoch, writer, args) elif n == 'test': tmp_test_metric = utils.validate(d, model, epoch, writer, args) # remember best mse and save checkpoint if args.local_rank == 0: is_best = tmp_val_metric > best_val_metric best_val_metric = max(tmp_val_metric, best_val_metric) torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if is_best: test_metric = tmp_test_metric shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) print('best val performance: {:.3f}'.format(best_val_metric)) print('test performance: {:.3f}'.format(test_metric)) def train(train_loader, model, optimizer, epoch, writer, args): batch_time = 
AverageMeter('Time', ':3.1f') losses = AverageMeter('Loss', ':3.2f') # switch to train mode model.train() end = time.time() for i, (input, target, metadata) in enumerate(train_loader): # compute output output, _ = model(input.cuda()) loss = F.mse_loss(output, target.cuda()) # compute gradient and do optimizer step optimizer.zero_grad() with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() optimizer.step() if i % args.print_freq == 0: # Every print_freq iterations, check the loss, accuracy, and speed. # For best performance, it doesn't make sense to print these metrics every # iteration, since they incur an allreduce and some host<->device syncs. # Average loss and accuracy across processes for logging if args.distributed: reduced_loss = utils.reduce_tensor(loss.data, args.world_size) else: reduced_loss = loss.data # to_python_float incurs a host<->device sync losses.update(to_python_float(reduced_loss), input.size(0)) global_step = epoch * len(train_loader) + i torch.cuda.synchronize() batch_time.update((time.time() - end) / args.print_freq) end = time.time() if args.local_rank == 0: writer.add_scalar("train/loss", to_python_float(reduced_loss), global_step) print('Epoch: [{0}][{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Speed {3:.3f} ({4:.3f})\t' 'Loss {loss.val:.10f} ({loss.avg:.4f})'.format( epoch, i, len(train_loader), args.world_size * args.batch_size[0] / batch_time.val, args.world_size * args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss=losses)) if __name__ == '__main__': model_names = utils.get_model_names() parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters parser.add_argument('data_dir', metavar='DIR', help='root path of dataset') parser.add_argument('--unlabeled-list', nargs='+', default=[]) parser.add_argument('--test-list', nargs='+', default=['val', 'test']) parser.add_argument('--metric', default='r_wg', help='metric used to evaluate model 
performance.' '(default: worst-U/R Pearson r)') parser.add_argument('--split-scheme', type=str, help='Identifies how the train/val/test split is constructed.' 'Choices are dataset-specific.') parser.add_argument('--fold', type=str, default='A', choices=['A', 'B', 'C', 'D', 'E'], help='Fold for poverty dataset. Poverty has 5 different cross validation folds,' 'each splitting the countries differently.') parser.add_argument('--use-unlabeled', action='store_true', help='Whether use unlabeled data for training or not.') # model parameters parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18_ms', choices=model_names, help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18_ms)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') # Learning rate schedule parameters parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, metavar='LR', help='Learning rate') parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, metavar='W', help='weight decay (default: 1e-4)') parser.add_argument('--gamma', type=int, default=0.96, help='parameter for StepLR scheduler') parser.add_argument('--step-size', type=int, default=1, help='parameter for StepLR scheduler') # training parameters parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=(64, 64), type=int, nargs='+', metavar='N', help='mini-batch size per process for source' ' and target domain (default: (64, 64))') parser.add_argument('--print-freq', '-p', default=50, type=int, metavar='N', help='print frequency (default: 50)') parser.add_argument('--deterministic', action='store_true') parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. 
') parser.add_argument('--local_rank', default=os.getenv('LOCAL_RANK', 0), type=int) parser.add_argument('--sync-bn', action='store_true', help='enabling apex sync BN.') parser.add_argument('--opt-level', type=str) parser.add_argument('--keep-batchnorm-fp32', type=str, default=None) parser.add_argument('--loss-scale', type=str, default=None) parser.add_argument('--channels-last', type=bool, default=False) parser.add_argument('--log', type=str, default='src_only', help='Where to save logs, checkpoints and debugging images.') parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_poverty/erm.sh ================================================ # official split scheme CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold A \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_A CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold B \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_B CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold C \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_C CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold D \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_D CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds --split-scheme official --fold E \ --arch 'resnet18_ms' --lr 1e-3 --epochs 200 -b 64 64 --opt-level O1 --deterministic --log logs/erm/poverty_fold_E ================================================ FILE: 
examples/domain_adaptation/wilds_poverty/requirements.txt ================================================ wilds tensorflow tensorboard ================================================ FILE: examples/domain_adaptation/wilds_poverty/resnet_ms.py ================================================ """ Modified based on torchvision.models.resnet @author: Jiaxin Li @contact: thulijx@gmail.com """ import torch.nn as nn from torchvision import models from torchvision.models.resnet import BasicBlock, Bottleneck import copy __all__ = ['resnet18_ms', 'resnet34_ms', 'resnet50_ms', 'resnet101_ms', 'resnet152_ms'] class ResNetMS(models.ResNet): """ ResNet with input channels parameter, without fully connected layer. """ def __init__(self, in_channels, *args, **kwargs): super(ResNetMS, self).__init__(*args, **kwargs) self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False) self._out_features = self.fc.in_features nn.init.kaiming_normal_(self.conv1.weight, mode='fan_out', nonlinearity='relu') def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) # x = self.avgpool(x) # x = torch.flatten(x, 1) # x = self.fc(x) return x @property def out_features(self) -> int: """The dimension of output features""" return self._out_features def copy_head(self) -> nn.Module: """Copy the origin fully connected layer""" return copy.deepcopy(self.fc) def resnet18_ms(num_channels=3): model = ResNetMS(num_channels, BasicBlock, [2, 2, 2, 2]) return model def resnet34_ms(num_channels=3): model = ResNetMS(num_channels, BasicBlock, [3, 4, 6, 3]) return model def resnet50_ms(num_channels=3): model = ResNetMS(num_channels, Bottleneck, [3, 4, 6, 3]) return model def resnet101_ms(num_channels=3): model = ResNetMS(num_channels, Bottleneck, [3, 4, 23, 3]) return model def resnet152_ms(num_channels=3): model = ResNetMS(num_channels, Bottleneck, [3, 8, 36, 3]) return model 
================================================ FILE: examples/domain_adaptation/wilds_poverty/utils.py ================================================ """ @author: Jiaxin Li @contact: thulijx@gmail.com """ import time import sys from typing import Tuple, Optional, List, Dict import torch import torch.nn as nn import torch.distributed as dist from torch.utils.data import DataLoader, ConcatDataset from torch.utils.data.distributed import DistributedSampler import wilds import resnet_ms as models sys.path.append('../../..') from tllib.utils.meter import AverageMeter, ProgressMeter class Regressor(nn.Module): """A generic Regressor class for domain adaptation. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: -1 head (torch.nn.Module, optional): Any regressor head. Use :class:`torch.nn.Linear` by default finetune (bool): Whether finetune the regressor or train from scratch. Default: True .. note:: Different regressors are used in different domain adaptation algorithms to achieve better accuracy respectively, and we provide a suggested `Regressor` for different algorithms. Remember they are not the core of algorithms. You can implement your own `Regressor` and combine it with the domain adaptation algorithm in this algorithm library. .. note:: The learning rate of this regressor is set 10 times to that of the feature extractor for better accuracy by default. If you have other optimization strategies, please over-ride :meth:`~Regressor.get_parameters`. 
def get_dataset(dataset_name, root, unlabeled_list=("test_unlabeled",), test_list=("test",), split_scheme='official',
                transform_train=None, transform_test=None, use_unlabeled=True, verbose=True, **kwargs):
    """Build the WILDS train/unlabeled/test subsets for ``dataset_name``.

    Args:
        dataset_name (str): name of a dataset registered in ``wilds``
        root (str): root directory where the dataset is (or will be) stored
        unlabeled_list (sequence of str): unlabeled split names to load when ``use_unlabeled``
        test_list (sequence of str): evaluation split names
        split_scheme (str): WILDS split scheme passed through to ``wilds.get_dataset``
        transform_train / transform_test: transforms applied to train / evaluation subsets
        use_unlabeled (bool): whether to also load the unlabeled splits
        verbose (bool): print the size of every subset

    Returns:
        (train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_channels);
        ``train_unlabeled_dataset`` is ``None`` when ``use_unlabeled`` is False.
    """
    labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, split_scheme=split_scheme,
                                        **kwargs)
    train_labeled_dataset = labeled_dataset.get_subset("train", transform=transform_train)
    if use_unlabeled:
        unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True)
        train_unlabeled_datasets = [
            unlabeled_dataset.get_subset(u, transform=transform_train)
            for u in unlabeled_list
        ]
        train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets)
    else:
        unlabeled_list = []
        train_unlabeled_datasets = []
        train_unlabeled_dataset = None
    test_datasets = [
        labeled_dataset.get_subset(t, transform=transform_test)
        for t in test_list
    ]
    # Number of input channels, inferred from the first sample's leading dimension.
    num_channels = labeled_dataset.get_input(0).size()[0]
    if verbose:
        print("Datasets")
        # FIX: `unlabeled_list` / `test_list` default to tuples; concatenating a tuple onto
        # a list raised TypeError. Convert explicitly so both tuples and lists work.
        for n, d in zip(["train"] + list(unlabeled_list) + list(test_list),
                        [train_labeled_dataset] + train_unlabeled_datasets + test_datasets):
            print("\t{}:{}".format(n, len(d)))
    return train_labeled_dataset, train_unlabeled_dataset, test_datasets, num_channels


def get_model_names():
    """Return the (sorted) names of all model factories exposed by the `models` module."""
    return sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith('__') and callable(models.__dict__[name])
    )


def get_model(arch, num_channels):
    """Instantiate architecture ``arch`` with ``num_channels`` input channels.

    Raises:
        ValueError: if ``arch`` is not a factory in the `models` module.
    """
    if arch in models.__dict__:
        model = models.__dict__[arch](num_channels=num_channels)
    else:
        raise ValueError('{} is not supported'.format(arch))
    return model


def collate_list(vec):
    """
    Adapted from https://github.com/p-lambda/wilds
    If vec is a list of Tensors, it concatenates them all along the first dimension.

    If vec is a list of lists, it joins these lists together, but does not attempt to
    recursively collate. This allows each element of the list to be, e.g., its own dict.

    If vec is a list of dicts (with the same keys in each dict), it returns a single dict
    with the same keys. For each key, it recursively collates all entries in the list.
    """
    if not isinstance(vec, list):
        raise TypeError("collate_list must take in a list")
    elem = vec[0]
    if torch.is_tensor(elem):
        return torch.cat(vec)
    elif isinstance(elem, list):
        return [obj for sublist in vec for obj in sublist]
    elif isinstance(elem, dict):
        return {k: collate_list([d[k] for d in vec]) for k in elem}
    else:
        raise TypeError("Elements of the list to collate must be tensors or dicts.")


def reduce_tensor(tensor, world_size):
    """All-reduce (sum) `tensor` across processes and divide by `world_size` (mean)."""
    rt = tensor.clone()
    # FIX: `dist.reduce_op` was deprecated and removed in recent PyTorch; use `dist.ReduceOp`.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt


def validate(val_dataset, model, epoch, writer, args):
    """Evaluate `model` on `val_dataset` with the official WILDS evaluator.

    Only rank 0 prints progress, logs TensorBoard scalars, and returns the metric named by
    ``args.metric``; other ranks return None.
    """
    val_sampler = None
    if args.distributed:
        val_sampler = DistributedSampler(val_dataset)
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size[0], shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    all_y_true = []
    all_y_pred = []
    all_metadata = []
    batch_time = AverageMeter('Time', ':6.3f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()
    end = time.time()
    for i, (input, target, metadata) in enumerate(val_loader):
        # compute output
        with torch.no_grad():
            output = model(input.cuda()).cpu()
        # NOTE(review): raw model outputs (not argmax'd labels) are handed to
        # `val_dataset.eval` here — confirm the dataset's evaluator accepts them.
        all_y_true.append(target)
        all_y_pred.append(output)
        all_metadata.append(metadata)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and i % args.print_freq == 0:
            progress.display(i)

    if args.local_rank == 0:
        # evaluate
        results = val_dataset.eval(
            collate_list(all_y_pred),
            collate_list(all_y_true),
            collate_list(all_metadata)
        )
        print(results[1])
        for k, v in results[0].items():
            if v == 0 or "Other" in k:
                continue
            writer.add_scalar("test/{}".format(k), v, global_step=epoch)
        return results[0][args.metric]
**pytorch==1.10.1** in order to reproduce the benchmark results. You need to run ``` pip install -r requirements.txt ``` ## Dataset Following datasets can be downloaded automatically: - [CivilComments (WILDS)](https://wilds.stanford.edu/datasets/) - [Amazon (WILDS)](https://wilds.stanford.edu/datasets/) ## Supported Methods TODO ## Usage The shell files give all the training scripts we use, e.g. ``` CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \ --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \ --seed 0 --deterministic --log logs/erm/civilcomments ``` ## Results ### Performance on WILDS-CivilComments (DistilBert) | Methods | Val Avg Acc | Val Worst-Group Acc | Test Avg Acc | Test Worst-Group Acc | GPU Memory Usage(GB)| | --- | --- | --- | --- | --- | --- | | ERM | 89.2 | 67.7 | 88.9 | 68.5 | 6.4 | ### Performance on WILDS-Amazon (DistilBert) | Methods | Val Avg Acc | Test Avg Acc | Val 10% Acc | Test 10% Acc | GPU Memory Usage(GB)| | --- | --- | --- | --- | --- | --- | | ERM | 72.6 | 71.6 | 54.7 | 53.8 | 12.8 | ### Visualization We use tensorboard to record the training process and visualize the outputs of the models. 
"""
@author: Jiaxin Li
@contact: thulijx@gmail.com
"""
import argparse
import shutil
import time
import pprint

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.tensorboard import SummaryWriter
from transformers import AdamW, get_linear_schedule_with_warmup
import wilds
from wilds.common.grouper import CombinatorialGrouper

import utils
from tllib.utils.logger import CompleteLogger
from tllib.utils.meter import AverageMeter
from tllib.utils.metric import accuracy


def main(args):
    """Train (or just evaluate, depending on --phase) a text classifier on a WILDS dataset."""
    logger = CompleteLogger(args.log, args.phase)
    writer = SummaryWriter(args.log)
    pprint.pprint(args)

    print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))
    cudnn.benchmark = True
    if args.deterministic:
        cudnn.benchmark = False
        cudnn.deterministic = True
        torch.manual_seed(args.seed)
        torch.set_printoptions(precision=10)

    # Data loading code
    train_transform = utils.get_transform(args.arch, args.max_token_length)
    val_transform = utils.get_transform(args.arch, args.max_token_length)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, args.num_classes, args.class_names = \
        utils.get_dataset(args.data, args.data_dir, args.unlabeled_list, args.test_list,
                          train_transform, val_transform, use_unlabeled=args.use_unlabeled, verbose=True)

    # create model
    print("=> using model '{}'".format(args.arch))
    model = utils.get_model(args.arch, args.num_classes)
    # FIX: the original chained a no-argument `.to()` after `.cuda()`, which is a no-op.
    model = model.cuda()

    # Optionally re-weight sampling so every group is drawn uniformly.
    train_labeled_sampler = None
    if args.uniform_over_groups:
        train_grouper = CombinatorialGrouper(dataset=labeled_dataset, groupby_fields=args.groupby_fields)
        groups, group_counts = train_grouper.metadata_to_group(train_labeled_dataset.metadata_array,
                                                               return_counts=True)
        group_weights = 1 / group_counts
        weights = group_weights[groups]
        train_labeled_sampler = WeightedRandomSampler(weights, len(train_labeled_dataset), replacement=True)

    train_labeled_loader = DataLoader(
        train_labeled_dataset, batch_size=args.batch_size[0],
        shuffle=(train_labeled_sampler is None), num_workers=args.workers, pin_memory=True,
        sampler=train_labeled_sampler
    )

    # Standard BERT recipe: no weight decay on biases and LayerNorm weights.
    no_decay = ['bias', 'LayerNorm.weight']
    decay_params = []
    no_decay_params = []
    for names, params in model.named_parameters():
        if any(nd in names for nd in no_decay):
            no_decay_params.append(params)
        else:
            decay_params.append(params)
    params = [
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0}
    ]
    optimizer = AdamW(params, lr=args.lr)
    # NOTE(review): the schedule is sized in *batches* (len(loader) * epochs) but it is only
    # stepped once per epoch below, so the LR decays far slower than intended — confirm.
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_training_steps=len(train_labeled_loader) * args.epochs,
                                                   num_warmup_steps=0)
    lr_scheduler.step_every_batch = True
    lr_scheduler.use_metric = False

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    if args.phase == 'test':
        # resume from the best checkpoint
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        model.load_state_dict(checkpoint)
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            utils.validate(d, model, -1, writer, args)
        return

    best_val_metric = 0
    test_metric = 0
    for epoch in range(args.epochs):
        lr_scheduler.step(epoch)
        print(lr_scheduler.get_last_lr())
        writer.add_scalar("train/lr", lr_scheduler.get_last_lr()[-1], epoch)

        # train for one epoch
        train(train_labeled_loader, model, criterion, optimizer, epoch, writer, args)

        # evaluate on validation set
        # FIX: initialize both per-epoch metrics so a custom --test-list that omits
        # 'val'/'test' no longer raises UnboundLocalError below.
        tmp_val_metric = 0
        tmp_test_metric = 0
        for n, d in zip(args.test_list, test_datasets):
            print(n)
            if n == 'val':
                tmp_val_metric = utils.validate(d, model, epoch, writer, args)
            elif n == 'test':
                tmp_test_metric = utils.validate(d, model, epoch, writer, args)

        # remember best validation metric and save checkpoint
        is_best = tmp_val_metric > best_val_metric
        best_val_metric = max(tmp_val_metric, best_val_metric)
        torch.save(model.state_dict(), logger.get_checkpoint_path('latest'))
        if is_best:
            test_metric = tmp_test_metric
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))

    print('best val performance: {:.3f}'.format(best_val_metric))
    print('test performance: {:.3f}'.format(test_metric))
    logger.close()
    writer.close()


def train(train_loader, model, criterion, optimizer, epoch, writer, args):
    """Run one training epoch, logging loss/accuracy/speed every `args.print_freq` batches."""
    batch_time = AverageMeter('Time', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    top1 = AverageMeter('Top 1', ':3.1f')

    # switch to train mode
    model.train()
    end = time.time()

    for i, (input, target, metadata) in enumerate(train_loader):
        # compute output
        output = model(input.cuda())
        loss = criterion(output, target.cuda())

        # compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % args.print_freq == 0:
            # Every print_freq iterations, check the loss, accuracy, and speed.
            # For best performance, it doesn't make sense to print these metrics every
            # iteration, since they incur an allreduce and some host<->device syncs.

            # Measure accuracy
            prec1, = accuracy(output.data, target.cuda(), topk=(1,))
            # FIX: record Python floats, not CUDA tensors, so the meters do not keep the
            # computation graph / device memory alive between iterations.
            losses.update(loss.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            global_step = epoch * len(train_loader) + i
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()
            writer.add_scalar("train/loss", loss.item(), global_step)
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {3:.3f} ({4:.3f})\t'
                  'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                epoch, i, len(train_loader), args.batch_size[0] / batch_time.val,
                args.batch_size[0] / batch_time.avg, batch_time=batch_time, loss=losses, top1=top1))


if __name__ == '__main__':
    model_names = utils.get_model_names()
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
    # Dataset parameters
    parser.add_argument('data_dir', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='civilcomments',
                        choices=wilds.supported_datasets,
                        help='dataset: ' + ' | '.join(wilds.supported_datasets) + ' (default: civilcomments)')
    parser.add_argument('--unlabeled-list', nargs='+', default=[])
    parser.add_argument('--test-list', nargs='+', default=["val", "test"])
    parser.add_argument('--metric', default='acc_wg',
                        help='metric used to evaluate model performance. (default: worst group accuracy)')
    parser.add_argument('--uniform-over-groups', action='store_true',
                        help='sample examples such that batches are uniform over groups')
    # FIX: the two implicitly-concatenated help strings were missing a separating space.
    parser.add_argument('--groupby-fields', nargs='+',
                        help='Group data by given fields. It means that items which have the same '
                             'values in those fields should be grouped.')
    parser.add_argument('--use-unlabeled', action='store_true',
                        help='Whether use unlabeled data for training or not.')
    # model parameters
    parser.add_argument('--arch', '-a', metavar='ARCH', default='distilbert-base-uncased',
                        choices=model_names,
                        help='model architecture: ' + ' | '.join(model_names) +
                             ' (default: distilbert-base-uncased)')
    parser.add_argument('--max-token-length', type=int, default=300,
                        help='The maximum size of a sequence.')
    # Learning rate schedule parameters
    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='Learning rate.')
    parser.add_argument('--weight-decay', '--wd', default=0.01, type=float,
                        metavar='W', help='weight decay (default: 0.01)')
    # training parameters
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=5, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-b', '--batch-size', default=(16, 16), type=int, nargs='+', metavar='N',
                        help='mini-batch size per process for source'
                             ' and target domain (default: (16, 16))')
    parser.add_argument('--print-freq', '-p', default=200, type=int,
                        metavar='N', help='print frequency (default: 200)')
    parser.add_argument('--deterministic', action='store_true')
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument('--log', type=str, default='src_only',
                        help='Where to save logs, checkpoints and debugging images.')
    # FIX: typo "'analysis'm" in the help string.
    parser.add_argument('--phase', type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
"When phase is 'analysis'm only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_adaptation/wilds_text/erm.sh ================================================ # civilcomments CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "civilcomments" --unlabeled-list "extra_unlabeled" \ --uniform-over-groups --groupby-fields y black --max-token-length 300 --lr 1e-05 --metric "acc_wg" \ --seed 0 --deterministic --log logs/erm/civilcomments # amazon CUDA_VISIBLE_DEVICES=0 python erm.py data/wilds -d "amazon" --max-token-length 512 \ --lr 1e-5 -b 24 24 --epochs 3 --metric "10th_percentile_acc" --seed 0 --deterministic --log logs/erm/amazon ================================================ FILE: examples/domain_adaptation/wilds_text/requirements.txt ================================================ wilds tensorflow tensorboard transformers ================================================ FILE: examples/domain_adaptation/wilds_text/utils.py ================================================ """ @author: Jiaxin Li @contact: thulijx@gmail.com """ import time import sys import torch import torch.distributed as dist from torch.utils.data import DataLoader, ConcatDataset from transformers import DistilBertTokenizerFast from transformers import DistilBertForSequenceClassification import wilds sys.path.append('../../..') from tllib.utils.meter import AverageMeter, ProgressMeter class DistilBertClassifier(DistilBertForSequenceClassification): """ Adapted from https://github.com/p-lambda/wilds """ def __call__(self, x): input_ids = x[:, :, 0] attention_mask = x[:, :, 1] outputs = super().__call__( input_ids=input_ids, attention_mask=attention_mask, )[0] return outputs def get_transform(arch, max_token_length): """ Adapted from https://github.com/p-lambda/wilds """ if arch == 'distilbert-base-uncased': tokenizer = DistilBertTokenizerFast.from_pretrained(arch) else: raise ValueError("Model: {arch} not 
recognized".format(arch)) def transform(text): tokens = tokenizer(text, padding='max_length', truncation=True, max_length=max_token_length, return_tensors='pt') if arch == 'bert_base_uncased': x = torch.stack( ( tokens["input_ids"], tokens["attention_mask"], tokens["token_type_ids"], ), dim=2, ) elif arch == 'distilbert-base-uncased': x = torch.stack((tokens["input_ids"], tokens["attention_mask"]), dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x return transform def get_dataset(dataset_name, root, unlabeled_list=('extra_unlabeled',), test_list=('test',), transform_train=None, transform_test=None, use_unlabeled=True, verbose=True): labeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True) train_labeled_dataset = labeled_dataset.get_subset('train', transform=transform_train) if use_unlabeled: unlabeled_dataset = wilds.get_dataset(dataset_name, root_dir=root, download=True, unlabeled=True) train_unlabeled_datasets = [ unlabeled_dataset.get_subset(u, transform=transform_train) for u in unlabeled_list ] train_unlabeled_dataset = ConcatDataset(train_unlabeled_datasets) else: unlabeled_list = [] train_unlabeled_datasets = [] train_unlabeled_dataset = None test_datasets = [ labeled_dataset.get_subset(t, transform=transform_test) for t in test_list ] num_classes = labeled_dataset.n_classes class_names = list(range(num_classes)) if verbose: print('Datasets') for n, d in zip(['train'] + unlabeled_list + test_list, [train_labeled_dataset, ] + train_unlabeled_datasets + test_datasets): print('\t{}:{}'.format(n, len(d))) print('\t#classes:', num_classes) return train_labeled_dataset, train_unlabeled_dataset, test_datasets, labeled_dataset, num_classes, class_names def get_model_names(): return ['distilbert-base-uncased'] def get_model(arch, num_classes): if arch == 'distilbert-base-uncased': model = DistilBertClassifier.from_pretrained(arch, num_labels=num_classes) else: raise ValueError('{} is not supported'.format(arch)) return 
model def reduce_tensor(tensor, world_size): rt = tensor.clone() dist.all_reduce(rt, op=dist.reduce_op.SUM) rt /= world_size return rt def collate_list(vec): """ Adapted from https://github.com/p-lambda/wilds If vec is a list of Tensors, it concatenates them all along the first dimension. If vec is a list of lists, it joins these lists together, but does not attempt to recursively collate. This allows each element of the list to be, e.g., its own dict. If vec is a list of dicts (with the same keys in each dict), it returns a single dict with the same keys. For each key, it recursively collates all entries in the list. """ if not isinstance(vec, list): raise TypeError("collate_list must take in a list") elem = vec[0] if torch.is_tensor(elem): return torch.cat(vec) elif isinstance(elem, list): return [obj for sublist in vec for obj in sublist] elif isinstance(elem, dict): return {k: collate_list([d[k] for d in vec]) for k in elem} else: raise TypeError("Elements of the list to collate must be tensors or dicts.") def validate(val_dataset, model, epoch, writer, args): val_sampler = None val_loader = DataLoader( val_dataset, batch_size=args.batch_size[0], shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler) all_y_true = [] all_y_pred = [] all_metadata = [] batch_time = AverageMeter('Time', ':6.3f') progress = ProgressMeter( len(val_loader), [batch_time], prefix='Test: ') # switch to evaluate mode model.eval() end = time.time() for i, (input, target, metadata) in enumerate(val_loader): # compute output with torch.no_grad(): output = model(input.cuda()).cpu() all_y_true.append(target) all_y_pred.append(output.argmax(1)) all_metadata.append(metadata) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if args.local_rank == 0 and i % args.print_freq == 0: progress.display(i) # evaluate results = val_dataset.eval( collate_list(all_y_pred), collate_list(all_y_true), collate_list(all_metadata) ) print(results[1]) for k, v in 
results[0].items(): if v == 0 or "Other" in k: continue writer.add_scalar("test/{}".format(k), v, global_step=epoch) return results[0][args.metric] ================================================ FILE: examples/domain_generalization/image_classification/README.md ================================================ # Domain Generalization for Image Classification ## Installation It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results. Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [Office31](https://www.cc.gatech.edu/~judy/domainadapt/) - [OfficeHome](https://www.hemanthdv.org/officeHomeDataset.html) - [DomainNet](http://ai.bu.edu/M3SDA/) - [PACS](https://domaingeneralization.github.io/#data) ## Supported Methods - [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008) - [Learning to Generalize: Meta-Learning for Domain Generalization (MLDG, 2018 AAAI)](https://arxiv.org/pdf/1710.03463.pdf) - [Invariant Risk Minimization (IRM)](https://arxiv.org/abs/1907.02893) - [Out-of-Distribution Generalization via Risk Extrapolation (VREx, 2021 ICML)](https://arxiv.org/abs/2003.00688) - [Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization (GroupDRO)](https://arxiv.org/abs/1911.08731) - [Deep CORAL: Correlation Alignment for Deep Domain Adaptation (Deep Coral, 2016 ECCV)](https://arxiv.org/abs/1607.01719) ## Usage The shell files give the script to reproduce the benchmark with specified hyper-parameters. 
For example, if you want to train IRM on Office-Home, use the following script ```shell script # Train with IRM on Office-Home Ar Cl Rw -> Pr task using ResNet 50. # Assume you have put the datasets under the path `data/office-home`, # or you are glad to download the datasets automatically from the Internet to this path CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr ``` Note that ``-s`` specifies the source domain, ``-t`` specifies the target domain, and ``--log`` specifies where to store results. ## Experiment and Results Following [DomainBed](https://github.com/facebookresearch/DomainBed), we select hyper-parameters based on the model's performance on `training-domain validation set` (first rule in DomainBed). Concretely, we save model with the highest accuracy on `training-domain validation set` and then load this checkpoint to test on the target domain. Here are some differences between our implementation and DomainBed. For the model, we do not freeze `BatchNorm2d` layers and do not insert additional `Dropout` layer except for `PACS` dataset. For the optimizer, we use `SGD` with momentum by default and find this usually achieves better performance than `Adam`. **Notations** - ``ERM`` refers to the model trained with data from the source domain. - ``Avg`` is the accuracy reported by `TLlib`. 
### PACS accuracy on ResNet-50 | Methods | avg | A | C | P | S | |----------|------|------|------|------|------| | ERM | 86.4 | 88.5 | 78.4 | 97.2 | 81.4 | | IBN | 87.8 | 88.2 | 84.5 | 97.1 | 81.4 | | MixStyle | 87.4 | 87.8 | 82.3 | 95.0 | 84.5 | | MLDG | 87.2 | 88.2 | 81.4 | 96.6 | 82.5 | | IRM | 86.9 | 88.0 | 82.5 | 98.0 | 79.0 | | VREx | 87.0 | 87.2 | 82.3 | 97.4 | 81.0 | | GroupDRO | 87.3 | 88.9 | 81.7 | 97.8 | 80.8 | | CORAL | 86.4 | 89.1 | 80.0 | 97.4 | 79.1 | ### Office-Home accuracy on ResNet-50 | Methods | avg | A | C | P | R | |----------|------|------|------|------|------| | ERM | 70.8 | 68.3 | 55.9 | 78.9 | 80.0 | | IBN | 69.9 | 67.4 | 55.2 | 77.3 | 79.6 | | MixStyle | 71.7 | 66.8 | 58.1 | 78.0 | 79.9 | | MLDG | 70.3 | 65.9 | 57.6 | 78.2 | 79.6 | | IRM | 70.3 | 66.7 | 54.8 | 78.6 | 80.9 | | VREx | 70.2 | 66.9 | 54.9 | 78.2 | 80.9 | | GroupDRO | 70.0 | 66.7 | 55.2 | 78.8 | 79.9 | | CORAL | 70.9 | 68.3 | 55.4 | 78.8 | 81.0 | ## Citation If you use these methods in your research, please consider citing. 
``` @inproceedings{IBN-Net, author = {Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang}, title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net}, booktitle = {ECCV}, year = {2018} } @inproceedings{mixstyle, title={Domain Generalization with MixStyle}, author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao}, booktitle={ICLR}, year={2021} } @inproceedings{MLDG, title={Learning to Generalize: Meta-Learning for Domain Generalization}, author={Li, Da and Yang, Yongxin and Song, Yi-Zhe and Hospedales, Timothy}, booktitle={AAAI}, year={2018} } @misc{IRM, title={Invariant Risk Minimization}, author={Martin Arjovsky and Léon Bottou and Ishaan Gulrajani and David Lopez-Paz}, year={2020}, eprint={1907.02893}, archivePrefix={arXiv}, primaryClass={stat.ML} } @inproceedings{VREx, title={Out-of-Distribution Generalization via Risk Extrapolation (REx)}, author={David Krueger and Ethan Caballero and Joern-Henrik Jacobsen and Amy Zhang and Jonathan Binas and Dinghuai Zhang and Remi Le Priol and Aaron Courville}, year={2021}, booktitle={ICML}, } @inproceedings{GroupDRO, title={Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization}, author={Shiori Sagawa and Pang Wei Koh and Tatsunori B. 
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.alignment.coral import CorrelationAlignmentLoss
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Train / test / analyze a CORAL domain-generalization classifier."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # NOTE(review): benchmark=True can pick non-deterministic algorithms even with
    # cudnn.deterministic set above — confirm this is intended for seeded runs.
    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root,
                                                   task_list=args.sources, split='train', download=True,
                                                   transform=train_transform, seed=args.seed)
    # Each mini-batch mixes samples from `n_domains_per_batch` source domains.
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size,
                                        n_domains_per_batch=args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                       split='val', download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets,
                                        split='test', download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn,
                                       dropout_p=args.dropout_p, finetune=args.finetune,
                                       pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # define loss function
    correlation_alignment_loss = CorrelationAlignmentLoss().to(device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        # FIX: `get_lr()` is deprecated (and warns / may return warmed values on recent
        # PyTorch); `get_last_lr()` is the public accessor.
        print(lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, correlation_alignment_loss,
              args.n_domains_per_batch, epoch, args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))
    logger.close()


def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          correlation_alignment_loss: CorrelationAlignmentLoss, n_domains_per_batch: int,
          epoch: int, args: argparse.Namespace):
    """One epoch of CORAL training: cross-entropy per domain plus pairwise feature alignment."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # compute output
        y_all, f_all = model(x_all)

        # measure data loading time
        data_time.update(time.time() - end)

        # separate into different domains
        y_all = y_all.chunk(n_domains_per_batch, dim=0)
        f_all = f_all.chunk(n_domains_per_batch, dim=0)
        labels_all = labels_all.chunk(n_domains_per_batch, dim=0)

        loss_ce = 0
        loss_penalty = 0
        cls_acc = 0
        for domain_i in range(n_domains_per_batch):
            # cls loss
            y_i, labels_i = y_all[domain_i], labels_all[domain_i]
            loss_ce += F.cross_entropy(y_i, labels_i)
            # update acc
            cls_acc += accuracy(y_i, labels_i)[0] / n_domains_per_batch
            # correlation alignment loss over every ordered pair (i < j)
            for domain_j in range(domain_i + 1, n_domains_per_batch):
                f_i = f_all[domain_i]
                f_j = f_all[domain_j]
                loss_penalty += correlation_alignment_loss(f_i, f_j)

        # normalize loss
        loss_ce /= n_domains_per_batch
        # FIX: with a single domain per batch the pair count is 0, which previously caused
        # a ZeroDivisionError (and `loss_penalty` stayed a plain int). Guard the division
        # and record the penalty via float() so both the tensor and int-0 cases work.
        n_pairs = n_domains_per_batch * (n_domains_per_batch - 1) / 2
        if n_pairs > 0:
            loss_penalty /= n_pairs
        loss = loss_ce + loss_penalty * args.trade_off

        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(float(loss_penalty), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CORAL for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None, help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None, help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true',
                        help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1,
                        help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('--trade-off', default=1, type=float,
                        help='the trade off hyper parameter for correlation alignment loss')
    parser.add_argument('-b', '--batch-size', default=36, type=int,
                        metavar='N', help='mini-batch size (default: 36)')
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float,
                        metavar='W', help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='coral',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
"""
Empirical Risk Minimization (ERM) baseline for domain generalization on image
classification: train a classifier on the pooled source domains with plain
cross-entropy, select by source-validation accuracy, and report accuracy on
the unseen target domain(s).

@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Build data loaders, model and optimizer, then run the phase selected by
    ``args.phase``: 'train' (default), 'test' (evaluate the best checkpoint) or
    'analysis' (t-SNE plot and A-distance between val and test features)."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: train on source 'train' splits, validate on source
    # 'val' splits, and evaluate generalization on the target 'test' splits.
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root,
                                                   task_list=args.sources, split='train', download=True,
                                                   transform=train_transform, seed=args.seed)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                       split='val', download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets,
                                        split='test', download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn,
                                       dropout_p=args.dropout_p, finetune=args.finetune,
                                       pool_layer=pool_layer).to(device)

    # define optimizer and lr scheduler (one scheduler step per iteration,
    # hence epochs * iters_per_epoch total steps)
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        # NOTE(review): get_lr() is deprecated in newer torch in favor of
        # get_last_lr() — confirm against the pinned torch version.
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint; 'best' is selected by
        # source-validation accuracy only
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (tracked only to report the oracle number below,
        # never used for model selection)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # final evaluation: reload the best (by validation) checkpoint and test it
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))

    logger.close()


def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,
          args: argparse.Namespace):
    """Run one epoch (``args.iters_per_epoch`` iterations) of plain
    cross-entropy training, stepping optimizer and lr scheduler per batch."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # loader yields (image, label, domain_id); domain id unused by ERM
        x, labels, _ = next(train_iter)
        x = x.to(device)
        labels = labels.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y, _ = model(x)
        loss = F.cross_entropy(y, labels)

        cls_acc = accuracy(y, labels)[0]
        losses.update(loss.item(), x.size(0))
        cls_accs.update(cls_acc.item(), x.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Baseline for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None, help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None, help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W',
                        help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    # NOTE(review): sibling scripts (coral.py, groupdro.py) default --seed to
    # None (unseeded); here the default is 0, so this script always seeds and
    # always emits the determinism warning — confirm this is intentional.
    parser.add_argument('--seed', default=0, type=int, help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
"""
GroupDRO for domain generalization on image classification: maintain a weight
per training domain (updated from per-domain losses via
AutomaticUpdateDomainWeightModule) and minimize the weighted sum of per-domain
cross-entropy losses, emphasizing the currently worst-performing domains.

Adapted from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Build data loaders (with a per-domain batch sampler), model, optimizer
    and the GroupDRO domain-weight module, then run the phase selected by
    ``args.phase``: 'train' (default), 'test' or 'analysis'."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                                random_color_jitter=True, random_gray_scale=True)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root,
                                                   task_list=args.sources, split='train', download=True,
                                                   transform=train_transform, seed=args.seed)
    # sampler groups each mini-batch into n_domains_per_batch contiguous
    # per-domain chunks, which train() relies on when chunking the batch
    sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                              sampler=sampler, drop_last=True)
    val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources,
                                       split='val', download=True, transform=val_transform, seed=args.seed)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets,
                                        split='test', download=True, transform=val_transform, seed=args.seed)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    print("train_dataset_size: ", len(train_dataset))
    print('val_dataset_size: ', len(val_dataset))
    print("test_dataset_size: ", len(test_dataset))
    train_iter = ForeverDataIterator(train_loader)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn,
                                       dropout_p=args.dropout_p, finetune=args.finetune,
                                       pool_layer=pool_layer).to(device)
    # total number of source domains (train_dataset is a concatenation of
    # per-domain datasets, hence .datasets)
    num_all_domains = len(train_dataset.datasets)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum,
                    weight_decay=args.wd, nesterov=True)
    lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch)
    # keeps one weight per source domain, updated each iteration with step
    # size args.eta from the per-domain losses (GroupDRO)
    domain_weight_module = AutomaticUpdateDomainWeightModule(num_all_domains, args.eta, device)

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analysis the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer,
                                          classifier.bottleneck).to(device)
        source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100)
        target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100)
        print(len(source_feature), len(target_feature))
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_val_acc1 = 0.
    best_test_acc1 = 0.
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())
        # train for one epoch
        train(train_iter, classifier, optimizer, lr_scheduler, domain_weight_module,
              args.n_domains_per_batch, epoch, args)

        # evaluate on validation set
        print("Evaluate on validation set...")
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint; selection uses source
        # validation accuracy only
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_val_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_val_acc1 = max(acc1, best_val_acc1)

        # evaluate on test set (oracle tracking only, never for selection)
        print("Evaluate on test set...")
        best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device))

    # final evaluation: reload the best (by validation) checkpoint and test it
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test acc on test set = {}".format(acc1))
    print("oracle acc on test set = {}".format(best_test_acc1))

    logger.close()


def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          domain_weight_module: AutomaticUpdateDomainWeightModule, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """Run one epoch of GroupDRO training: forward each per-domain chunk
    separately, update the domain weights from the per-domain losses, and step
    on the weighted sum of losses."""
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        # loader yields (image, label, domain_id); samples of the same domain
        # are contiguous thanks to RandomDomainSampler
        x_all, labels_all, domain_labels = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)
        domain_labels = domain_labels.to(device)

        # get selected domain idxes (first label of each chunk identifies the
        # whole chunk's domain)
        domain_labels = domain_labels.chunk(n_domains_per_batch, dim=0)
        sampled_domain_idxes = [domain_labels[i][0].item() for i in range(n_domains_per_batch)]

        # measure data loading time
        data_time.update(time.time() - end)

        loss_per_domain = torch.zeros(n_domains_per_batch).to(device)
        cls_acc = 0
        for domain_id, (x_per_domain, labels_per_domain) in enumerate(
                zip(x_all.chunk(n_domains_per_batch, dim=0), labels_all.chunk(n_domains_per_batch, dim=0))):
            y_per_domain, _ = model(x_per_domain)
            loss_per_domain[domain_id] = F.cross_entropy(y_per_domain, labels_per_domain)
            cls_acc += accuracy(y_per_domain, labels_per_domain)[0] / n_domains_per_batch

        # update domain weight
        domain_weight_module.update(loss_per_domain, sampled_domain_idxes)
        domain_weight = domain_weight_module.get_domain_weight(sampled_domain_idxes)

        # weighted cls loss
        loss = (loss_per_domain * domain_weight).sum()

        losses.update(loss.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GroupDRO for Domain Generalization')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='PACS',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: PACS)')
    parser.add_argument('-s', '--sources', nargs='+', default=None, help='source domain(s)')
    parser.add_argument('-t', '--targets', nargs='+', default=None, help='target domain(s)')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.')
    parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone')
    parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers')
    parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True')
    # training parameters
    parser.add_argument('--eta', default=1e-2, type=float, help='the eta hyper parameter')
    parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N',
                        help='mini-batch size (default: 36)')
    parser.add_argument('--n-domains-per-batch', default=3, type=int,
                        help='number of domains in each mini-batch')
    parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W',
                        help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='groupdro',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model."
                             "When phase is 'analysis', only analysis the model.")
    args = parser.parse_args()
    main(args)
-d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_q CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_r CUDA_VISIBLE_DEVICES=0 python groupdro.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --lr 0.005 --seed 0 --log logs/groupdro/DomainNet_s ================================================ FILE: examples/domain_generalization/image_classification/ibn.sh ================================================ #!/usr/bin/env bash # IBN_ResNet50_b, PACS CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s A C S -t P -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_P CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P C S -t A -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_A CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A S -t C -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_C CUDA_VISIBLE_DEVICES=0 python erm.py data/PACS -d PACS -s P A C -t S -a resnet50_ibn_b --freeze-bn --seed 0 --log logs/erm/PACS_S # IBN_ResNet50_b, Office-Home CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Pr CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Rw CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Cl CUDA_VISIBLE_DEVICES=0 python erm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50_ibn_b --seed 0 --log logs/erm/OfficeHome_Ar # IBN_ResNet50_b, DomainNet CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50_ibn_b -i 2500 --lr 0.01 --seed 0 --log logs/erm/DomainNet_c CUDA_VISIBLE_DEVICES=0 python erm.py data/domainnet -d DomainNet -s c 
class InvariancePenaltyLoss(nn.Module):
    r"""Invariance penalty from `Invariant Risk Minimization
    <https://arxiv.org/abs/1907.02893>`_ (implementation follows
    `DomainBed <https://github.com/facebookresearch/DomainBed>`_).

    The batch is split into two halves (even- and odd-indexed rows). For each
    half we take the cross-entropy loss of the logits scaled by a frozen dummy
    classifier weight :math:`w = 1` and differentiate it with respect to that
    weight; the penalty is the elementwise product of the two gradients,
    summed. A small penalty means the optimal per-environment classifier is
    close to the shared one.

    Inputs:
        - y: predictions from model
        - labels: ground truth

    Shape:
        - y: :math:`(N, C)` where C means the number of classes.
        - labels: :math:`(N, )` where N mean mini-batch size
    """

    def __init__(self):
        super(InvariancePenaltyLoss, self).__init__()
        # Dummy scalar classifier weight; it is never optimized, only used as
        # the differentiation target when forming the penalty.
        self.scale = torch.tensor(1.).requires_grad_()

    def _grad_wrt_scale(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Gradient of CE(logits * scale, targets) w.r.t. the dummy scale,
        # keeping the graph so the penalty itself stays differentiable.
        half_loss = F.cross_entropy(logits * self.scale, targets)
        return autograd.grad(half_loss, [self.scale], create_graph=True)[0]

    def forward(self, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        grad_even = self._grad_wrt_scale(y[::2], labels[::2])
        grad_odd = self._grad_wrt_scale(y[1::2], labels[1::2])
        return torch.sum(grad_even * grad_odd)
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=True, random_gray_scale=True) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='train', download=True, transform=train_transform, seed=args.seed) sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=args.n_domains_per_batch) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, drop_last=True) val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val', download=True, transform=val_transform, seed=args.seed) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test', download=True, transform=val_transform, seed=args.seed) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("train_dataset_size: ", len(train_dataset)) print('val_dataset_size: ', len(val_dataset)) print("test_dataset_size: ", len(test_dataset)) train_iter = ForeverDataIterator(train_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p, finetune=args.finetune, pool_layer=pool_layer).to(device) # define optimizer and lr scheduler optimizer = 
SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch) # define loss function invariance_penalty_loss = InvariancePenaltyLoss().to(device) # for simplicity assert args.anneal_iters % args.iters_per_epoch == 0 # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100) target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100) print(len(source_feature), len(target_feature)) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_val_acc1 = 0. best_test_acc1 = 0. 
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR,
          invariance_penalty_loss: InvariancePenaltyLoss, n_domains_per_batch: int, epoch: int,
          args: argparse.Namespace):
    """Train ``model`` for one epoch with the IRM objective.

    The loss is cross entropy on the full batch plus the invariance penalty
    averaged over the per-domain chunks of the batch. Each batch is split into
    ``n_domains_per_batch`` contiguous chunks via ``Tensor.chunk``; ``main``
    builds the loader with ``RandomDomainSampler`` using the same
    ``n_domains_per_batch``, so each chunk presumably holds samples of a single
    domain -- confirm against the sampler. The penalty weight is 1 during the
    first ``args.anneal_iters`` global iterations and ``args.trade_off``
    afterwards.
    """
    # running statistics displayed through ProgressMeter
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    losses_ce = AverageMeter('CELoss', ':3.2f')
    losses_penalty = AverageMeter('Penalty Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_all, labels_all, _ = next(train_iter)
        x_all = x_all.to(device)
        labels_all = labels_all.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_all, _ = model(x_all)

        # cls loss
        loss_ce = F.cross_entropy(y_all, labels_all)
        # penalty loss, accumulated per domain chunk
        loss_penalty = 0
        for y_per_domain, labels_per_domain in zip(y_all.chunk(n_domains_per_batch, dim=0),
                                                   labels_all.chunk(n_domains_per_batch, dim=0)):
            # normalize loss by domain num
            loss_penalty += invariance_penalty_loss(y_per_domain, labels_per_domain) / n_domains_per_batch

        # anneal the penalty weight: 1 before args.anneal_iters global
        # iterations, args.trade_off afterwards
        global_iter = epoch * args.iters_per_epoch + i
        if global_iter >= args.anneal_iters:
            trade_off = args.trade_off
        else:
            trade_off = 1

        loss = loss_ce + loss_penalty * trade_off

        cls_acc = accuracy(y_all, labels_all)[0]
        losses.update(loss.item(), x_all.size(0))
        losses_ce.update(loss_ce.item(), x_all.size(0))
        losses_penalty.update(loss_penalty.item(), x_all.size(0))
        cls_accs.update(cls_acc.item(), x_all.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers') parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True') # training parameters parser.add_argument('--trade-off', default=1, type=float, help='the trade off hyper parameter for irm penalty') parser.add_argument('--anneal-iters', default=500, type=int, help='anneal iterations (trade off is set to 1 during these iterations)') parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--n-domains-per-batch', default=3, type=int, help='number of domains in each mini-batch') parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='irm', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/image_classification/irm.sh ================================================ #!/usr/bin/env bash # ResNet50, PACS CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_P CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_A CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_C CUDA_VISIBLE_DEVICES=0 python irm.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/irm/PACS_S # ResNet50, Office-Home CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/irm/OfficeHome_Pr CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/irm/OfficeHome_Rw CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/irm/OfficeHome_Cl CUDA_VISIBLE_DEVICES=0 python irm.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/irm/OfficeHome_Ar # ResNet50, DomainNet CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_c CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_i 
CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_p CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_q CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_r CUDA_VISIBLE_DEVICES=0 python irm.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --seed 0 --log logs/irm/DomainNet_s ================================================ FILE: examples/domain_generalization/image_classification/mixstyle.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils import tllib.normalization.mixstyle.resnet as models from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=True, random_gray_scale=True) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='train', download=True, transform=train_transform, seed=args.seed) sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=2) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, drop_last=True) val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val', download=True, transform=val_transform, seed=args.seed) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test', download=True, transform=val_transform, seed=args.seed) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("train_dataset_size: ", len(train_dataset)) print('val_dataset_size: ', len(val_dataset)) print("test_dataset_size: ", len(test_dataset)) train_iter = ForeverDataIterator(train_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha, pretrained=True) pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p, finetune=args.finetune, pool_layer=pool_layer).to(device) # define optimizer and lr 
scheduler optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100) target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100) print(len(source_feature), len(target_feature)) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_val_acc1 = 0. best_test_acc1 = 0. 
def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, epoch: int,
          args: argparse.Namespace):
    """Run one plain ERM training epoch (MixStyle acts inside the backbone)."""
    # progress statistics
    batch_time = AverageMeter('Time', ':4.2f')
    data_time = AverageMeter('Data', ':3.1f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(args.iters_per_epoch,
                             [batch_time, data_time, losses, cls_accs],
                             prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    tic = time.time()
    for step in range(args.iters_per_epoch):
        images, targets, _ = next(train_iter)
        images, targets = images.to(device), targets.to(device)

        # data loading time
        data_time.update(time.time() - tic)

        # forward and classification loss
        logits, _ = model(images)
        loss = F.cross_entropy(logits, targets)

        # bookkeeping
        top1 = accuracy(logits, targets)[0]
        batch = images.size(0)
        losses.update(loss.item(), batch)
        cls_accs.update(top1.item(), batch)

        # SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # wall-clock time per iteration
        batch_time.update(time.time() - tic)
        tic = time.time()

        if step % args.print_freq == 0:
            progress.display(step)
default=1e-3, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='mixstyle', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/image_classification/mixstyle.sh ================================================ #!/usr/bin/env bash # ResNet50, PACS CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s A C S -t P -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_P CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P C S -t A -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_A CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A S -t C -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_C CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/PACS -d PACS -s P A C -t S -a resnet50 --mix-layers layer1 layer2 layer3 --freeze-bn --seed 0 --log logs/mixstyle/PACS_S # ResNet50, Office-Home CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Pr CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Rw CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Cl CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --mix-layers layer1 layer2 --seed 0 --log logs/mixstyle/OfficeHome_Ar # ResNet50, DomainNet CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_c CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 2500 --lr 0.01 --seed 0 --log 
logs/mixstyle/DomainNet_i CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_p CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_q CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_r CUDA_VISIBLE_DEVICES=0 python mixstyle.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 2500 --lr 0.01 --seed 0 --log logs/mixstyle/DomainNet_s ================================================ FILE: examples/domain_generalization/image_classification/mldg.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader import torch.nn.functional as F import higher import utils from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=True, random_gray_scale=True) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='train', download=True, transform=train_transform, seed=args.seed) n_domains_per_batch = args.n_support_domains + args.n_query_domains sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, n_domains_per_batch=n_domains_per_batch) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, drop_last=True) val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val', download=True, transform=val_transform, seed=args.seed) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test', download=True, transform=val_transform, seed=args.seed) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("train_dataset_size: ", len(train_dataset)) print('val_dataset_size: ', len(val_dataset)) print("test_dataset_size: ", len(test_dataset)) train_iter = ForeverDataIterator(train_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p, finetune=args.finetune, pool_layer=pool_layer).to(device) # define optimizer and lr 
scheduler optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100) target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100) print(len(source_feature), len(target_feature)) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_val_acc1 = 0. best_test_acc1 = 0. 
def random_split(x_list, labels_list, n_domains_per_batch, n_support_domains):
    """Randomly partition per-domain ``(x, labels)`` pairs into support / query sets.

    ``n_support_domains`` domain indices are drawn uniformly at random (without
    replacement) to form the support set; the remaining domains form the query
    set. Both returned lists keep the original domain order.

    Args:
        x_list: per-domain inputs, one entry per domain.
        labels_list: per-domain labels, aligned with ``x_list``.
        n_domains_per_batch: total number of domains in the mini-batch.
        n_support_domains: number of domains placed in the support set.

    Returns:
        A pair ``(support_domain_list, query_domain_list)`` of lists of
        ``(x, labels)`` tuples.
    """
    assert n_support_domains < n_domains_per_batch

    chosen = set(random.sample(range(n_domains_per_batch), n_support_domains))
    support_domain_list, query_domain_list = [], []
    for idx in range(n_domains_per_batch):
        pair = (x_list[idx], labels_list[idx])
        if idx in chosen:
            support_domain_list.append(pair)
        else:
            query_domain_list.append(pair)
    return support_domain_list, query_domain_list
end = time.time() for i in range(args.iters_per_epoch): x, labels, _ = next(train_iter) x = x.to(device) labels = labels.to(device) # measure data loading time data_time.update(time.time() - end) # split into support domain and query domain x_list = x.chunk(n_domains_per_batch, dim=0) labels_list = labels.chunk(n_domains_per_batch, dim=0) support_domain_list, query_domain_list = random_split(x_list, labels_list, n_domains_per_batch, args.n_support_domains) # clear grad optimizer.zero_grad() # compute output with higher.innerloop_ctx(model, optimizer, copy_initial_weights=False) as (inner_model, inner_optimizer): # perform inner optimization for _ in range(args.inner_iters): loss_inner = 0 for (x_s, labels_s) in support_domain_list: y_s, _ = inner_model(x_s) # normalize loss by support domain num loss_inner += F.cross_entropy(y_s, labels_s) / args.n_support_domains inner_optimizer.step(loss_inner) # calculate outer loss loss_outer = 0 cls_acc = 0 # loss on support domains for (x_s, labels_s) in support_domain_list: y_s, _ = model(x_s) # normalize loss by support domain num loss_outer += F.cross_entropy(y_s, labels_s) / args.n_support_domains # loss on query domains for (x_q, labels_q) in query_domain_list: y_q, _ = inner_model(x_q) # normalize loss by query domain num loss_outer += F.cross_entropy(y_q, labels_q) * args.trade_off / args.n_query_domains cls_acc += accuracy(y_q, labels_q)[0] / args.n_query_domains # update statistics losses.update(loss_outer.item(), args.batch_size) cls_accs.update(cls_acc.item(), args.batch_size) # compute gradient and do SGD step loss_outer.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Meta Learning for Domain Generalization') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') 
parser.add_argument('-d', '--data', metavar='DATA', default='PACS', help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: PACS)') parser.add_argument('-s', '--sources', nargs='+', default=None, help='source domain(s)') parser.add_argument('-t', '--targets', nargs='+', default=None, help='target domain(s)') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers') parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True') # training parameters parser.add_argument('--n-support-domains', type=int, default=1, help='Number of support domains sampled in each iteration') parser.add_argument('--n-query-domains', type=int, default=2, help='Number of query domains in each iteration') parser.add_argument('--trade-off', type=float, default=1, help='hyper parameter beta') parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers 
(default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('--inner-iters', default=1, type=int, help='Number of iterations in inner loop') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='mldg', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/image_classification/mldg.sh ================================================ #!/usr/bin/env bash # ResNet50, PACS CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_P CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_A CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_C CUDA_VISIBLE_DEVICES=0 python mldg.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/mldg/PACS_S # ResNet50, Office-Home CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Pr CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Rw CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a 
resnet50 --seed 0 --log logs/mldg/OfficeHome_Cl CUDA_VISIBLE_DEVICES=0 python mldg.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/mldg/OfficeHome_Ar # ResNet50, DomainNet CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_c CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_i CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_p CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_q CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_r CUDA_VISIBLE_DEVICES=0 python mldg.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 5000 --lr 0.005 --seed 0 --log logs/mldg/DomainNet_s ================================================ FILE: examples/domain_generalization/image_classification/requirements.txt ================================================ timm wilds higher ================================================ FILE: examples/domain_generalization/image_classification/utils.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import copy import random import sys import time import timm import tqdm import torch import torch.nn as nn import torchvision.transforms as T import torch.nn.functional as F import numpy as np from torch.utils.data import Sampler, Subset, ConcatDataset sys.path.append('../../..') from tllib.modules import Classifier as ClassifierBase import tllib.vision.datasets as datasets import tllib.vision.models as models import tllib.normalization.ibn as ibn_models 
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter


def get_model_names():
    """Return all supported backbone names: tllib vision models, IBN-Net variants and timm models."""
    def _public_callables(module):
        # lowercase, public, callable entries of a module are its model constructors
        return sorted(name for name in module.__dict__
                      if name.islower() and not name.startswith("__") and callable(module.__dict__[name]))

    return _public_callables(models) + _public_callables(ibn_models) + timm.list_models()


def get_model(model_name):
    """Create a pretrained backbone by name and strip its classification head.

    ``model_name`` is looked up first in tllib.vision.models, then in
    tllib.normalization.ibn, and finally in pytorch-image-models (timm).
    The returned module exposes ``out_features`` for downstream heads.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    elif model_name in ibn_models.__dict__:
        # load models (with ibn) from tllib.normalization.ibn
        backbone = ibn_models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except Exception:
            # FIX: was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
            # Some timm models lack get_classifier/reset_classifier; fall back to `head`.
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone


def get_dataset_names():
    """Return the names of all dataset classes available in tllib.vision.datasets."""
    return sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )


class ConcatDatasetWithDomainLabel(ConcatDataset):
    """ConcatDataset that additionally returns the id of the domain each sample comes from."""

    def __init__(self, *args, **kwargs):
        super(ConcatDatasetWithDomainLabel, self).__init__(*args, **kwargs)
        # map every global sample index to the index of its source dataset (= its domain)
        self.index_to_domain_id = {}
        start = 0
        for domain_id, end in enumerate(self.cumulative_sizes):
            for idx in range(start, end):
                self.index_to_domain_id[idx] = domain_id
            start = end

    def __getitem__(self, index):
        img, target = super(ConcatDatasetWithDomainLabel, self).__getitem__(index)
        return img, target, self.index_to_domain_id[index]


def get_dataset(dataset_name, root, task_list, split='train', download=True, transform=None, seed=0):
    """Build train/val/test datasets for domain generalization.

    Returns the dataset for the requested ``split`` together with the number of classes.
    """
    assert split in ['train', 'val', 'test']
    # load datasets from tllib.vision.datasets
    # currently only PACS, OfficeHome
and DomainNet are supported supported_dataset = ['PACS', 'OfficeHome', 'DomainNet'] assert dataset_name in supported_dataset dataset = datasets.__dict__[dataset_name] train_split_list = [] val_split_list = [] test_split_list = [] # we follow DomainBed and split each dataset randomly into two parts, with 80% samples and 20% samples # respectively, the former (larger) will be used as training set, and the latter will be used as validation set. split_ratio = 0.8 num_classes = 0 # under domain generalization setting, we use all samples in target domain as test set for task in task_list: if dataset_name == 'PACS': all_split = dataset(root=root, task=task, split='all', download=download, transform=transform) num_classes = all_split.num_classes elif dataset_name == 'OfficeHome': all_split = dataset(root=root, task=task, download=download, transform=transform) num_classes = all_split.num_classes elif dataset_name == 'DomainNet': train_split = dataset(root=root, task=task, split='train', download=download, transform=transform) test_split = dataset(root=root, task=task, split='test', download=download, transform=transform) num_classes = train_split.num_classes all_split = ConcatDataset([train_split, test_split]) train_split, val_split = split_dataset(all_split, int(len(all_split) * split_ratio), seed) train_split_list.append(train_split) val_split_list.append(val_split) test_split_list.append(all_split) train_dataset = ConcatDatasetWithDomainLabel(train_split_list) val_dataset = ConcatDatasetWithDomainLabel(val_split_list) test_dataset = ConcatDatasetWithDomainLabel(test_split_list) dataset_dict = { 'train': train_dataset, 'val': val_dataset, 'test': test_dataset } return dataset_dict[split], num_classes def split_dataset(dataset, n, seed=0): """ Return a pair of datasets corresponding to a random split of the given dataset, with n data points in the first dataset and the rest in the last, using the given random seed """ assert (n <= len(dataset)) idxes = 
list(range(len(dataset))) np.random.RandomState(seed).shuffle(idxes) subset_1 = idxes[:n] subset_2 = idxes[n:] return Subset(dataset, subset_1), Subset(dataset, subset_2) def validate(val_loader, model, args, device) -> float: batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1], prefix='Test: ') # switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (images, target, _) in enumerate(val_loader): images = images.to(device) target = target.to(device) # compute output output = model(images) loss = F.cross_entropy(output, target) # measure accuracy and record loss acc1 = accuracy(output, target)[0] losses.update(loss.item(), images.size(0)) top1.update(acc1.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1)) return top1.avg def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=True, random_gray_scale=True): """ resizing mode: - default: random resized crop with scale factor(0.7, 1.0) and size 224; - cen.crop: take the center crop of 224; - res.|cen.crop: resize the image to 256 and take the center crop of size 224; - res: resize the image to 224; - res2x: resize the image to 448; - res.|crop: resize the image to 256 and take a random crop of size 224; - res.sma|crop: resize the image keeping its aspect ratio such that the smaller side is 256, then take a random crop of size 224; – inc.crop: “inception crop” from (Szegedy et al., 2015); – cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224. 
""" if resizing == 'default': transform = T.RandomResizedCrop(224, scale=(0.7, 1.0)) elif resizing == 'cen.crop': transform = T.CenterCrop(224) elif resizing == 'res.|cen.crop': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224) ]) elif resizing == 'res': transform = ResizeImage(224) elif resizing == 'res2x': transform = ResizeImage(448) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize((256, 256)), T.RandomCrop(224) ]) elif resizing == "res.sma|crop": transform = T.Compose([ T.Resize(256), T.RandomCrop(224) ]) elif resizing == 'inc.crop': transform = T.RandomResizedCrop(224) elif resizing == 'cif.crop': transform = T.Compose([ T.Resize((224, 224)), T.Pad(28), T.RandomCrop(224), ]) else: raise NotImplementedError(resizing) transforms = [transform] if random_horizontal_flip: transforms.append(T.RandomHorizontalFlip()) if random_color_jitter: transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3)) if random_gray_scale: transforms.append(T.RandomGrayscale()) transforms.extend([ T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return T.Compose(transforms) def get_val_transform(resizing='default'): """ resizing mode: - default: resize the image to 224; - res2x: resize the image to 448; - res.|cen.crop: resize the image to 256 and take the center crop of size 224; """ if resizing == 'default': transform = ResizeImage(224) elif resizing == 'res2x': transform = ResizeImage(448) elif resizing == 'res.|cen.crop': transform = T.Compose([ ResizeImage(256), T.CenterCrop(224), ]) else: raise NotImplementedError(resizing) return T.Compose([ transform, T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def collect_feature(data_loader, feature_extractor: nn.Module, device: torch.device, max_num_features=None) -> torch.Tensor: """ Fetch data from `data_loader`, and then use `feature_extractor` to collect features. 
This function is specific for domain generalization because each element in data_loader is a tuple (images, labels, domain_labels). Args: data_loader (torch.utils.data.DataLoader): Data loader. feature_extractor (torch.nn.Module): A feature extractor. device (torch.device) max_num_features (int): The max number of features to return Returns: Features in shape (min(len(data_loader), max_num_features * mini-batch size), :math:`|\mathcal{F}|`). """ feature_extractor.eval() all_features = [] with torch.no_grad(): for i, (images, target, domain_labels) in enumerate(tqdm.tqdm(data_loader)): if max_num_features is not None and i >= max_num_features: break images = images.to(device) feature = feature_extractor(images).cpu() all_features.append(feature) return torch.cat(all_features, dim=0) class ImageClassifier(ClassifierBase): """ImageClassifier specific for reproducing results of `DomainBed `_. You are free to freeze all `BatchNorm2d` layers and insert one additional `Dropout` layer, this can achieve better results for some datasets like PACS but may be worse for others. Args: backbone (torch.nn.Module): Any backbone to extract features from data num_classes (int): Number of classes freeze_bn (bool, optional): whether to freeze all `BatchNorm2d` layers. Default: False dropout_p (float, optional): dropout ratio for additional `Dropout` layer, this layer is only used when `freeze_bn` is True. 
Default: 0.1 """ def __init__(self, backbone: nn.Module, num_classes: int, freeze_bn=False, dropout_p=0.1, **kwargs): super(ImageClassifier, self).__init__(backbone, num_classes, **kwargs) self.freeze_bn = freeze_bn if freeze_bn: self.feature_dropout = nn.Dropout(p=dropout_p) def forward(self, x: torch.Tensor): f = self.pool_layer(self.backbone(x)) f = self.bottleneck(f) if self.freeze_bn: f = self.feature_dropout(f) predictions = self.head(f) if self.training: return predictions, f else: return predictions def train(self, mode=True): super(ImageClassifier, self).train(mode) if self.freeze_bn: for m in self.modules(): if isinstance(m, nn.BatchNorm2d): m.eval() class RandomDomainSampler(Sampler): r"""Randomly sample :math:`N` domains, then randomly select :math:`K` samples in each domain to form a mini-batch of size :math:`N\times K`. Args: data_source (ConcatDataset): dataset that contains data from multiple domains batch_size (int): mini-batch size (:math:`N\times K` here) n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here) """ def __init__(self, data_source: ConcatDataset, batch_size: int, n_domains_per_batch: int): super(Sampler, self).__init__() self.n_domains_in_dataset = len(data_source.cumulative_sizes) self.n_domains_per_batch = n_domains_per_batch assert self.n_domains_in_dataset >= self.n_domains_per_batch self.sample_idxes_per_domain = [] start = 0 for end in data_source.cumulative_sizes: idxes = [idx for idx in range(start, end)] self.sample_idxes_per_domain.append(idxes) start = end assert batch_size % n_domains_per_batch == 0 self.batch_size_per_domain = batch_size // n_domains_per_batch self.length = len(list(self.__iter__())) def __iter__(self): sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain) domain_idxes = [idx for idx in range(self.n_domains_in_dataset)] final_idxes = [] stop_flag = False while not stop_flag: selected_domains = random.sample(domain_idxes, self.n_domains_per_batch) 
for domain in selected_domains: sample_idxes = sample_idxes_per_domain[domain] if len(sample_idxes) < self.batch_size_per_domain: selected_idxes = np.random.choice(sample_idxes, self.batch_size_per_domain, replace=True) else: selected_idxes = random.sample(sample_idxes, self.batch_size_per_domain) final_idxes.extend(selected_idxes) for idx in selected_idxes: if idx in sample_idxes_per_domain[domain]: sample_idxes_per_domain[domain].remove(idx) remaining_size = len(sample_idxes_per_domain[domain]) if remaining_size < self.batch_size_per_domain: stop_flag = True return iter(final_idxes) def __len__(self): return self.length ================================================ FILE: examples/domain_generalization/image_classification/vrex.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.analysis import tsne, a_distance device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, random_color_jitter=True, random_gray_scale=True) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, num_classes = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='train', download=True, transform=train_transform, seed=args.seed) sampler = utils.RandomDomainSampler(train_dataset, args.batch_size, args.n_domains_per_batch) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, drop_last=True) val_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.sources, split='val', download=True, transform=val_transform, seed=args.seed) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) test_dataset, _ = utils.get_dataset(dataset_name=args.data, root=args.root, task_list=args.targets, split='test', download=True, transform=val_transform, seed=args.seed) test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("train_dataset_size: ", len(train_dataset)) print('val_dataset_size: ', len(val_dataset)) print("test_dataset_size: ", len(test_dataset)) train_iter = ForeverDataIterator(train_loader) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, freeze_bn=args.freeze_bn, dropout_p=args.dropout_p, finetune=args.finetune, pool_layer=pool_layer).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, 
momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch) # for simplicity assert args.anneal_iters % args.iters_per_epoch == 0 # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # extract features from both domains feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device) source_feature = utils.collect_feature(val_loader, feature_extractor, device, max_num_features=100) target_feature = utils.collect_feature(test_loader, feature_extractor, device, max_num_features=100) print(len(source_feature), len(target_feature)) # plot t-SNE tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png') tsne.visualize(source_feature, target_feature, tSNE_filename) print("Saving t-SNE to", tSNE_filename) # calculate A-distance, which is a measure for distribution discrepancy A_distance = a_distance.calculate(source_feature, target_feature, device) print("A-distance =", A_distance) return if args.phase == 'test': acc1 = utils.validate(test_loader, classifier, args, device) print(acc1) return # start training best_val_acc1 = 0. best_test_acc1 = 0. 
for epoch in range(args.epochs): if epoch * args.iters_per_epoch == args.anneal_iters: # reset optimizer to avoid sharp jump in gradient magnitudes optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = CosineAnnealingLR(optimizer, args.epochs * args.iters_per_epoch - args.anneal_iters) print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, classifier, optimizer, lr_scheduler, args.n_domains_per_batch, epoch, args) # evaluate on validation set print("Evaluate on validation set...") acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_val_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_val_acc1 = max(acc1, best_val_acc1) # evaluate on test set print("Evaluate on test set...") best_test_acc1 = max(best_test_acc1, utils.validate(test_loader, classifier, args, device)) # evaluate on test set classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) acc1 = utils.validate(test_loader, classifier, args, device) print("test acc on test set = {}".format(acc1)) print("oracle acc on test set = {}".format(best_test_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model, optimizer, lr_scheduler: CosineAnnealingLR, n_domains_per_batch: int, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') losses_ce = AverageMeter('CELoss', ':3.2f') losses_penalty = AverageMeter('Penalty Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, losses_ce, losses_penalty, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in 
range(args.iters_per_epoch): x_all, labels_all, _ = next(train_iter) x_all = x_all.to(device) labels_all = labels_all.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_all, _ = model(x_all) loss_ce_per_domain = torch.zeros(n_domains_per_batch).to(device) for domain_id, (y_per_domain, labels_per_domain) in enumerate( zip(y_all.chunk(n_domains_per_batch, dim=0), labels_all.chunk(n_domains_per_batch, dim=0))): loss_ce_per_domain[domain_id] = F.cross_entropy(y_per_domain, labels_per_domain) # cls loss loss_ce = loss_ce_per_domain.mean() # penalty loss loss_penalty = ((loss_ce_per_domain - loss_ce) ** 2).mean() global_iter = epoch * args.iters_per_epoch + i if global_iter >= args.anneal_iters: trade_off = args.trade_off else: trade_off = 1 loss = loss_ce + loss_penalty * trade_off cls_acc = accuracy(y_all, labels_all)[0] losses.update(loss.item(), x_all.size(0)) losses_ce.update(loss_ce.item(), x_all.size(0)) losses_penalty.update(loss_penalty.item(), x_all.size(0)) cls_accs.update(cls_acc.item(), x_all.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='VREx for Domain Generalization') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', default='PACS', help='dataset: ' + ' | '.join(utils.get_dataset_names()) + ' (default: PACS)') parser.add_argument('-s', '--sources', nargs='+', default=None, help='source domain(s)') parser.add_argument('-t', '--targets', nargs='+', default=None, help='target domain(s)') parser.add_argument('--train-resizing', type=str, default='default') parser.add_argument('--val-resizing', type=str, default='default') # model parameters 
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--freeze-bn', action='store_true', help='whether freeze all bn layers') parser.add_argument('--dropout-p', type=float, default=0.1, help='only activated when freeze-bn is True') # training parameters parser.add_argument('--trade-off', default=3, type=float, help='the trade off hyper parameter for vrex penalty') parser.add_argument('--anneal-iters', default=500, type=int, help='anneal iterations (trade off is set to 1 during these iterations)') parser.add_argument('-b', '--batch-size', default=36, type=int, metavar='N', help='mini-batch size (default: 36)') parser.add_argument('--n-domains-per-batch', default=3, type=int, help='number of domains in each mini-batch') parser.add_argument('--lr', '--learning-rate', default=5e-4, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='vrex', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/image_classification/vrex.sh ================================================ #!/usr/bin/env bash # ResNet50, PACS CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s A C S -t P -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_P CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P C S -t A -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_A CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A S -t C -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_C CUDA_VISIBLE_DEVICES=0 python vrex.py data/PACS -d PACS -s P A C -t S -a resnet50 --freeze-bn --seed 0 --log logs/vrex/PACS_S # ResNet50, Office-Home CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Rw -t Pr -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Pr CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Cl Pr -t Rw -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Rw CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Ar Rw Pr -t Cl -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Cl CUDA_VISIBLE_DEVICES=0 python vrex.py data/office-home -d OfficeHome -s Cl Rw Pr -t Ar -a resnet50 --seed 0 --log logs/vrex/OfficeHome_Ar # ResNet50, DomainNet CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s i p q r s -t c -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_c CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c p q r s -t i -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 
--seed 0 --log logs/vrex/DomainNet_i CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i q r s -t p -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_p CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p r s -t q -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_q CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q s -t r -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_r CUDA_VISIBLE_DEVICES=0 python vrex.py data/domainnet -d DomainNet -s c i p q r -t s -a resnet50 -i 4000 --anneal-iters 4000 --lr 0.005 --trade-off 1 --seed 0 --log logs/vrex/DomainNet_s ================================================ FILE: examples/domain_generalization/re_identification/README.md ================================================ # Domain Generalization for Person Re-Identification ## Installation It’s suggested to use **pytorch==1.7.1** and torchvision==0.8.2 in order to reproduce the benchmark results. Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models). You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [Market1501](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html) - [DukeMTMC](https://exposing.ai/duke_mtmc/) - [MSMT17](https://arxiv.org/pdf/1711.08565.pdf) ## Supported Methods Supported methods include: - [Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (IBN-Net, 2018 ECCV)](https://openaccess.thecvf.com/content_ECCV_2018/papers/Xingang_Pan_Two_at_Once_ECCV_2018_paper.pdf) - [Domain Generalization with MixStyle (MixStyle, 2021 ICLR)](https://arxiv.org/abs/2104.02008) ## Usage The shell files give the script to reproduce the benchmark with specified hyper-parameters. 
For example, if you want to train MixStyle on the Market1501 -> DukeMTMC task, use the following script

```shell script
# Train MixStyle on Market1501 -> DukeMTMC task using ResNet 50.
# Assume you have put the datasets under the path `data/market1501` and `data/dukemtmc`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \
    --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke
```

### Experiment and Results

In our experiments, we adopt the modified ResNet architecture from [MMT](https://arxiv.org/pdf/2001.01526.pdf).
For a fair comparison, we use standard cross entropy loss and triplet loss in all methods.

**Notations**

- ``Avg`` means the mAP (mean average precision) reported by `TLlib`.

### Cross dataset mAP on ResNet-50

| Methods  | Avg  | Market2Duke | Duke2Market | Market2MSMT | MSMT2Market | Duke2MSMT | MSMT2Duke |
|----------|------|-------------|-------------|-------------|-------------|-----------|-----------|
| Baseline | 23.5 | 25.6        | 29.6        | 6.3         | 31.7        | 10.1      | 37.8      |
| IBN      | 27.0 | 31.5        | 33.3        | 10.4        | 33.6        | 13.7      | 40.0      |
| MixStyle | 25.5 | 27.2        | 31.6        | 8.2         | 33.9        | 12.4      | 39.9      |

## Citation

If you use these methods in your research, please consider citing.
``` @inproceedings{IBN-Net, author = {Xingang Pan, Ping Luo, Jianping Shi, and Xiaoou Tang}, title = {Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net}, booktitle = {ECCV}, year = {2018} } @inproceedings{mixstyle, title={Domain Generalization with MixStyle}, author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao}, booktitle={ICLR}, year={2021} } ``` ================================================ FILE: examples/domain_generalization/re_identification/baseline.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import numpy as np import torch import torch.nn as nn from torch.nn import DataParallel import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.utils.data import DataLoader import utils from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss from tllib.vision.models.reid.identifier import ReIdentifier import tllib.vision.datasets.reid as datasets from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset from tllib.utils.scheduler import WarmupMultiStepLR from tllib.utils.metric.reid import validate, visualize_ranked_results from tllib.utils.data import ForeverDataIterator, RandomMultipleGallerySampler from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing, random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False) val_transform = utils.get_val_transform(args.height, args.width) print("train_transform: ", train_transform) print("val_transform: ", val_transform) working_dir = osp.dirname(osp.abspath(__file__)) root = osp.join(working_dir, args.root) # source dataset source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower())) sampler = RandomMultipleGallerySampler(source_dataset.train, args.num_instances) train_loader = DataLoader( convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform), batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader( convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)), root=source_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # target dataset target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower())) test_loader = DataLoader( convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)), root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # create model num_classes = source_dataset.num_train_pids backbone = utils.get_model(args.arch) pool_layer = nn.Identity() if args.no_pool else None model = ReIdentifier(backbone, num_classes, finetune=args.finetune, pool_layer=pool_layer).to(device) model = DataParallel(model) # define optimizer and learning rate scheduler optimizer = 
Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay) lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1, warmup_steps=args.warmup_steps) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # plot t-SNE utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model, filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device) # visualize ranked results visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device, visualize_dir=logger.visualize_directory, width=args.width, height=args.height, rerank=args.rerank) return if args.phase == 'test': print("Test on source domain:") validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("Test on target domain:") validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) return # define loss function criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device) criterion_triplet = SoftTripletLoss(margin=args.margin).to(device) # start training best_val_mAP = 0. best_test_mAP = 0. 
for epoch in range(args.epochs): # print learning rate print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args) # update learning rate lr_scheduler.step() if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1): # evaluate on validation set print("Validation on source domain...") _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True) # remember best mAP and save checkpoint torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if val_mAP > best_val_mAP: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_val_mAP = max(val_mAP, best_val_mAP) # evaluate on test set print("Test on target domain...") _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) best_test_mAP = max(test_mAP, best_test_mAP) # evaluate on test set model.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) print("Test on target domain:") _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("test mAP on target = {}".format(test_mAP)) print("oracle mAP on target = {}".format(best_test_mAP)) logger.close() def train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses_ce = AverageMeter('CeLoss', ':3.2f') losses_triplet = AverageMeter('TripletLoss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode 
model.train() end = time.time() for i in range(args.iters_per_epoch): x, _, labels, _ = next(train_iter) x = x.to(device) labels = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, f = model(x) # cross entropy loss loss_ce = criterion_ce(y, labels) # triplet loss loss_triplet = criterion_triplet(f, f, labels) loss = loss_ce + loss_triplet * args.trade_off cls_acc = accuracy(y, labels)[0] losses_ce.update(loss_ce.item(), x.size(0)) losses_triplet.update(loss_triplet.item(), x.size(0)) losses.update(loss.item(), x.size(0)) cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description="Baseline for Domain Generalizable ReID") # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-s', '--source', type=str, help='source domain') parser.add_argument('-t', '--target', type=str, help='target domain') parser.add_argument('--train-resizing', type=str, default='default') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='reid_resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: reid_resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--rate', type=float, default=0.2) # training parameters parser.add_argument('--trade-off', type=float, default=1, help='trade-off hyper parameter between 
cross entropy loss and triplet loss') parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard') parser.add_argument('-j', '--workers', type=int, default=4) parser.add_argument('-b', '--batch-size', type=int, default=16) parser.add_argument('--height', type=int, default=256, help="input height") parser.add_argument('--width', type=int, default=128, help="input width") parser.add_argument('--num-instances', type=int, default=4, help="each minibatch consist of " "(batch_size // num_instances) identities, and " "each identity has num_instances instances, " "default: 4") parser.add_argument('--lr', type=float, default=0.00035, help="initial learning rate") parser.add_argument('--weight-decay', type=float, default=5e-4) parser.add_argument('--epochs', type=int, default=80) parser.add_argument('--warmup-steps', type=int, default=10, help='number of warp-up steps') parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70], help='milestones for the learning rate decay') parser.add_argument('--eval-step', type=int, default=40) parser.add_argument('--iters-per-epoch', type=int, default=400) parser.add_argument('--print-freq', type=int, default=40) parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.') parser.add_argument('--rerank', action='store_true', help="evaluation only") parser.add_argument("--log", type=str, default='baseline', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/re_identification/baseline.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/Market2Duke # Duke -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/Duke2Market # Market1501 -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/Market2MSMT # MSMT -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/MSMT2Market # Duke -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/Duke2MSMT # MSMT -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a reid_resnet50 \ --finetune --seed 0 --log logs/baseline/MSMT2Duke ================================================ FILE: examples/domain_generalization/re_identification/ibn.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/Market2Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t DukeMTMC -a resnet50_ibn_b \ --finetune --seed 0 --log logs/ibn/Market2Duke # Duke -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/Duke2Market CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t Market1501 -a resnet50_ibn_b \ --finetune --seed 0 --log 
logs/ibn/Duke2Market # Market1501 -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/Market2MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s Market1501 -t MSMT17 -a resnet50_ibn_b \ --finetune --seed 0 --log logs/ibn/Market2MSMT # MSMT -> Market1501 CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/MSMT2Market CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t Market1501 -a resnet50_ibn_b \ --finetune --seed 0 --log logs/ibn/MSMT2Market # Duke -> MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/Duke2MSMT CUDA_VISIBLE_DEVICES=0 python baseline.py data -s DukeMTMC -t MSMT17 -a resnet50_ibn_b \ --finetune --seed 0 --log logs/ibn/Duke2MSMT # MSMT -> Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_a \ --finetune --seed 0 --log logs/ibn/MSMT2Duke CUDA_VISIBLE_DEVICES=0 python baseline.py data -s MSMT17 -t DukeMTMC -a resnet50_ibn_b \ --finetune --seed 0 --log logs/ibn/MSMT2Duke ================================================ FILE: examples/domain_generalization/re_identification/mixstyle.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import os.path as osp import numpy as np import torch from torch.nn import DataParallel import torch.backends.cudnn as cudnn from torch.optim import Adam from torch.utils.data import DataLoader import utils from tllib.normalization.mixstyle.sampler import RandomDomainMultiInstanceSampler import tllib.normalization.mixstyle.resnet as models from tllib.vision.models.reid.identifier import ReIdentifier from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss import 
tllib.vision.datasets.reid as datasets from tllib.vision.datasets.reid.convert import convert_to_pytorch_dataset from tllib.vision.models.reid.resnet import ReidResNet from tllib.utils.scheduler import WarmupMultiStepLR from tllib.utils.metric.reid import validate, visualize_ranked_results from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.height, args.width, args.train_resizing, random_horizontal_flip=True, random_color_jitter=False, random_gray_scale=False) val_transform = utils.get_val_transform(args.height, args.width) print("train_transform: ", train_transform) print("val_transform: ", val_transform) working_dir = osp.dirname(osp.abspath(__file__)) root = osp.join(working_dir, args.root) # source dataset source_dataset = datasets.__dict__[args.source](root=osp.join(root, args.source.lower())) sampler = RandomDomainMultiInstanceSampler(source_dataset.train, batch_size=args.batch_size, n_domains_per_batch=2, num_instances=args.num_instances) train_loader = DataLoader( convert_to_pytorch_dataset(source_dataset.train, root=source_dataset.images_dir, transform=train_transform), batch_size=args.batch_size, num_workers=args.workers, sampler=sampler, pin_memory=True, drop_last=True) train_iter = 
ForeverDataIterator(train_loader) val_loader = DataLoader( convert_to_pytorch_dataset(list(set(source_dataset.query) | set(source_dataset.gallery)), root=source_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # target dataset target_dataset = datasets.__dict__[args.target](root=osp.join(root, args.target.lower())) test_loader = DataLoader( convert_to_pytorch_dataset(list(set(target_dataset.query) | set(target_dataset.gallery)), root=target_dataset.images_dir, transform=val_transform), batch_size=args.batch_size, num_workers=args.workers, shuffle=False, pin_memory=True) # create model num_classes = source_dataset.num_train_pids backbone = models.__dict__[args.arch](mix_layers=args.mix_layers, mix_p=args.mix_p, mix_alpha=args.mix_alpha, resnet_class=ReidResNet, pretrained=True) model = ReIdentifier(backbone, num_classes, finetune=args.finetune).to(device) model = DataParallel(model) # define optimizer and learning rate scheduler optimizer = Adam(model.module.get_parameters(base_lr=args.lr, rate=args.rate), args.lr, weight_decay=args.weight_decay) lr_scheduler = WarmupMultiStepLR(optimizer, args.milestones, gamma=0.1, warmup_factor=0.1, warmup_steps=args.warmup_steps) # resume from the best checkpoint if args.phase != 'train': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') model.load_state_dict(checkpoint) # analysis the model if args.phase == 'analysis': # plot t-SNE utils.visualize_tsne(source_loader=val_loader, target_loader=test_loader, model=model, filename=osp.join(logger.visualize_directory, 'analysis', 'TSNE.pdf'), device=device) # visualize ranked results visualize_ranked_results(test_loader, model, target_dataset.query, target_dataset.gallery, device, visualize_dir=logger.visualize_directory, width=args.width, height=args.height, rerank=args.rerank) return if args.phase == 'test': print("Test on source domain:") validate(val_loader, model, 
source_dataset.query, source_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("Test on target domain:") validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) return # define loss function criterion_ce = CrossEntropyLossWithLabelSmooth(num_classes).to(device) criterion_triplet = SoftTripletLoss(margin=args.margin).to(device) # start training best_val_mAP = 0. best_test_mAP = 0. for epoch in range(args.epochs): # print learning rate print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, model, criterion_ce, criterion_triplet, optimizer, epoch, args) # update learning rate lr_scheduler.step() if (epoch + 1) % args.eval_step == 0 or (epoch == args.epochs - 1): # evaluate on validation set print("Validation on source domain...") _, val_mAP = validate(val_loader, model, source_dataset.query, source_dataset.gallery, device, cmc_flag=True) # remember best mAP and save checkpoint torch.save(model.state_dict(), logger.get_checkpoint_path('latest')) if val_mAP > best_val_mAP: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_val_mAP = max(val_mAP, best_val_mAP) # evaluate on test set print("Test on target domain...") _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) best_test_mAP = max(test_mAP, best_test_mAP) # evaluate on test set model.load_state_dict(torch.load(logger.get_checkpoint_path('best'))) print("Test on target domain:") _, test_mAP = validate(test_loader, model, target_dataset.query, target_dataset.gallery, device, cmc_flag=True, rerank=args.rerank) print("test mAP on target = {}".format(test_mAP)) print("oracle mAP on target = {}".format(best_test_mAP)) logger.close() def train(train_iter: ForeverDataIterator, model, criterion_ce: CrossEntropyLossWithLabelSmooth, criterion_triplet: SoftTripletLoss, optimizer: Adam, epoch: int, args: 
argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses_ce = AverageMeter('CeLoss', ':3.2f') losses_triplet = AverageMeter('TripletLoss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses_ce, losses_triplet, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, _, labels, _ = next(train_iter) x = x.to(device) labels = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, f = model(x) # cross entropy loss loss_ce = criterion_ce(y, labels) # triplet loss loss_triplet = criterion_triplet(f, f, labels) loss = loss_ce + loss_triplet * args.trade_off cls_acc = accuracy(y, labels)[0] losses_ce.update(loss_ce.item(), x.size(0)) losses_triplet.update(loss_triplet.item(), x.size(0)) losses.update(loss.item(), x.size(0)) cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': architecture_names = sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) dataset_names = sorted( name for name in datasets.__dict__ if not name.startswith("__") and callable(datasets.__dict__[name]) ) parser = argparse.ArgumentParser(description="MixStyle for Domain Generalizable ReID") # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-s', '--source', type=str, help='source domain') parser.add_argument('-t', '--target', type=str, help='target domain') parser.add_argument('--train-resizing', type=str, default='default') # 
model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=architecture_names, help='backbone architecture: ' + ' | '.join(architecture_names) + ' (default: resnet50)') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--rate', type=float, default=0.2) parser.add_argument('--mix-layers', nargs='+', help='layers to apply MixStyle') parser.add_argument('--mix-p', default=0.5, type=float, help='probability to apply MixStyle') parser.add_argument('--mix-alpha', default=0.1, type=float, help='parameter alpha for beta distribution') # training parameters parser.add_argument('--trade-off', type=float, default=1, help='trade-off hyper parameter between cross entropy loss and triplet loss') parser.add_argument('--margin', type=float, default=0.0, help='margin for the triplet loss with batch hard') parser.add_argument('-j', '--workers', type=int, default=4) parser.add_argument('-b', '--batch-size', type=int, default=16) parser.add_argument('--height', type=int, default=256, help="input height") parser.add_argument('--width', type=int, default=128, help="input width") parser.add_argument('--num-instances', type=int, default=4, help="each minibatch consist of " "(batch_size // num_instances) identities, and " "each identity has num_instances instances, " "default: 4") parser.add_argument('--lr', type=float, default=0.00035, help="learning rate of new parameters, for pretrained ") parser.add_argument('--weight-decay', type=float, default=5e-4) parser.add_argument('--epochs', type=int, default=80) parser.add_argument('--warmup-steps', type=int, default=10, help='number of warm-up steps') parser.add_argument('--milestones', nargs='+', type=int, default=[40, 70], help='milestones for the learning rate decay') parser.add_argument('--eval-step', type=int, default=40) parser.add_argument('--iters-per-epoch', type=int, default=400) parser.add_argument('--print-freq', 
type=int, default=40) parser.add_argument('--seed', default=None, type=int, help='seed for initializing training.') parser.add_argument('--rerank', action='store_true', help="evaluation only") parser.add_argument("--log", type=str, default='mixstyle', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'], help="When phase is 'test', only test the model." "When phase is 'analysis', only analysis the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/domain_generalization/re_identification/mixstyle.sh ================================================ #!/usr/bin/env bash # Market1501 -> Duke CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t DukeMTMC -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2Duke # Duke -> Market1501 CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t Market1501 -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2Market # Market1501 -> MSMT CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s Market1501 -t MSMT17 -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Market2MSMT # MSMT -> Market1501 CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t Market1501 -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Market # Duke -> MSMT CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s DukeMTMC -t MSMT17 -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/Duke2MSMT # MSMT -> Duke CUDA_VISIBLE_DEVICES=0 python mixstyle.py data -s MSMT17 -t DukeMTMC -a resnet50 \ --mix-layers layer1 layer2 --finetune --seed 0 --log logs/mixstyle/MSMT2Duke ================================================ FILE: examples/domain_generalization/re_identification/requirements.txt 
================================================
timm
opencv-python

================================================
FILE: examples/domain_generalization/re_identification/utils.py
================================================
"""
Shared helpers for the domain-generalization re-identification examples:
backbone lookup/construction, train/val image transforms, and a t-SNE
feature-visualization utility.

@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import sys
import timm
import torch
import torch.nn as nn
import torchvision.transforms as T

sys.path.append('../../..')
from tllib.utils.metric.reid import extract_reid_feature
from tllib.utils.analysis import tsne
import tllib.vision.models.reid as models
import tllib.normalization.ibn as ibn_models


def get_model_names():
    """Return the names of all supported backbones: tllib reid models,
    tllib IBN variants, and every model registered in timm."""
    return sorted(name for name in models.__dict__
                  if name.islower() and not name.startswith("__")
                  and callable(models.__dict__[name])) + \
           sorted(name for name in ibn_models.__dict__
                  if name.islower() and not name.startswith("__")
                  and callable(ibn_models.__dict__[name])) + \
           timm.list_models()


def get_model(model_name):
    """Construct a pretrained backbone by name.

    Resolution order: tllib reid models first, then tllib IBN variants,
    and finally timm. For timm models the classification head is removed
    so the backbone outputs features, and ``out_features`` is attached
    for downstream heads to read.
    """
    if model_name in models.__dict__:
        # load models from tllib.vision.models
        backbone = models.__dict__[model_name](pretrained=True)
    elif model_name in ibn_models.__dict__:
        # load models (with ibn) from tllib.normalization.ibn
        backbone = ibn_models.__dict__[model_name](pretrained=True)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=True)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        except:
            # NOTE(review): bare except — intended to catch models that expose
            # the head as `.head` instead of via get_classifier(), but it also
            # swallows unrelated errors (even KeyboardInterrupt). Narrowing to
            # AttributeError looks safer — confirm against the timm models used.
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    return backbone


def get_train_transform(height, width, resizing='default', random_horizontal_flip=True, random_color_jitter=False,
                        random_gray_scale=False):
    """Build the training-time augmentation pipeline.

    resizing mode:
        - default: resize the image to (height, width), zero-pad it by 10 on
            each side, then take a random crop of (height, width)
        - res: resize the image to (height, width)

    The optional flags append horizontal flip, color jitter and random
    grayscale augmentations; ToTensor + ImageNet normalization always follow.
    """
    if resizing == 'default':
        transform = T.Compose([
            # interpolation=3 is the integer form of PIL BICUBIC
            # (kept for compatibility with older torchvision versions)
            T.Resize((height, width), interpolation=3),
            T.Pad(10),
            T.RandomCrop((height, width))
        ])
    elif resizing == 'res':
        transform = T.Resize((height, width), interpolation=3)
    else:
        raise NotImplementedError(resizing)
    transforms = [transform]
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        transforms.append(T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3))
    if random_gray_scale:
        transforms.append(T.RandomGrayscale())
    transforms.extend([
        T.ToTensor(),
        # ImageNet channel statistics
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])
    return T.Compose(transforms)


def get_val_transform(height, width):
    """Build the deterministic evaluation pipeline: resize to
    (height, width), convert to tensor, apply ImageNet normalization."""
    return T.Compose([
        T.Resize((height, width), interpolation=3),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])


def visualize_tsne(source_loader, target_loader, model, filename, device, n_data_points_per_domain=3000):
    """Visualize features from different domains using t-SNE.

    As we can have very large number of samples in each domain, only
    `n_data_points_per_domain` number of samples are randomly selected
    in each domain. The resulting scatter plot (source in blue, target
    in orange) is written to ``filename``.
    """
    # extract L2-normalized reid features, then subsample each domain
    source_feature_dict = extract_reid_feature(source_loader, model, device, normalize=True)
    source_feature = torch.stack(list(source_feature_dict.values())).cpu()
    source_feature = source_feature[torch.randperm(len(source_feature))]
    source_feature = source_feature[:n_data_points_per_domain]

    target_feature_dict = extract_reid_feature(target_loader, model, device, normalize=True)
    target_feature = torch.stack(list(target_feature_dict.values())).cpu()
    target_feature = target_feature[torch.randperm(len(target_feature))]
    target_feature = target_feature[:n_data_points_per_domain]

    tsne.visualize(source_feature, target_feature, filename, source_color='cornflowerblue',
                   target_color='darkorange')
    print('T-SNE process is done, figure is saved to {}'.format(filename))

================================================
FILE: examples/model_selection/README.md
================================================
# Model Selection

## Installation

Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You need to install timm to use PyTorch-Image-Models.
```
pip install timm
```

## Dataset

- [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/)
- [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/)
- [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html)
- [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html)
- [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/)
- [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html)
- [SUN397](https://vision.princeton.edu/projects/2010/SUN/)

## Supported Methods

Supported methods include:

- [An Information-theoretic Approach to Transferability in Task Transfer Learning (H-Score, ICIP 2019)](http://yangli-feasibility.com/home/media/icip-19.pdf)
- [LEEP: A New Measure to Evaluate Transferability of Learned Representations (LEEP, ICML 2020)](http://proceedings.mlr.press/v119/nguyen20b/nguyen20b.pdf)
- [Log Maximum Evidence in LogME: Practical Assessment of Pre-trained Models for Transfer Learning (LogME, ICML 2021)](https://arxiv.org/pdf/2102.11005.pdf)
- [Negative Conditional Entropy in Transferability and Hardness of Supervised Classification Tasks (NCE, ICCV 2019)](https://arxiv.org/pdf/1908.08142v1.pdf)

## Experiment and Results

### Model Ranking on image classification tasks

The shell files give the scripts to rank pre-trained models on a given dataset.
For example, if you want to use LogME to calculate the transfer performance of ResNet50 (ImageNet pre-trained) on Aircraft, use the following script

```shell script
# Using LogME to rank pre-trained ResNet50 on Aircraft
# Assume you have put the datasets under the path `data/FGVCAircraft`,
# or you are glad to download the datasets automatically from the Internet to this path
CUDA_VISIBLE_DEVICES=0 python logme.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features
```

We use LEEP, NCE, HScore and LogME to compute scores by applying 10 pre-trained models to different datasets.
The correlation([Weighted kendall Tau](https://vigna.di.unimi.it/ftp/papers/WeightedTau.pdf)/Pearson Correlation) between scores and fine-tuned accuracies are presented. #### Model Ranking Benchmark on Aircraft | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 82.7 | 28.37 | -4.310 | 0.934 | -4.248 | | Inception V3 | 88.8 | 43.89 | -4.202 | 0.953 | -4.170 | | ResNet50 | 86.6 | 46.23 | -4.215 | 0.946 | -4.201 | | ResNet101 | 85.6 | 46.13 | -4.230 | 0.948 | -4.222 | | ResNet152 | 85.3 | 46.25 | -4.230 | 0.950 | -4.229 | | DenseNet121 | 85.4 | 31.53 | -4.228 | 0.938 | -4.215 | | DenseNet169 | 84.5 | 41.81 | -4.245 | 0.943 | -4.270 | | Densenet201 | 84.6 | 46.01 | -4.206 | 0.942 | -4.189 | | MobileNet V2 | 82.8 | 34.43 | -4.198 | 0.941 | -4.208 | | MNasNet | 72.8 | 35.28 | -4.192 | 0.948 | -4.195 | | Pearson Corr | - | 0.688 | 0.127 | 0.582 | 0.173 | | Weighted Tau | - | 0.664 | -0.264 | 0.595 | 0.002 | #### Model Ranking Benchmark on Caltech101 | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 91.7 | 75.88 | -1.462 | 1.228 | -0.665 | | Inception V3 | 94.3 | 93.73 | -1.119 | 1.387 | -0.560 | | ResNet50 | 91.8 | 91.65 | -1.020 | 1.262 | -0.616 | | ResNet101 | 93.1 | 92.54 | -0.899 | 1.305 | -0.603 | | ResNet152 | 93.2 | 92.91 | -0.875 | 1.324 | -0.605 | | DenseNet121 | 91.9 | 75.02 | -0.979 | 1.172 | -0.609 | | DenseNet169 | 92.5 | 86.37 | -0.864 | 1.212 | -0.580 | | Densenet201 | 93.4 | 89.90 | -0.914 | 1.228 | -0.590 | | MobileNet V2 | 89.1 | 75.82 | -1.115 | 1.150 | -0.693 | | MNasNet | 91.5 | 77.00 | -1.043 | 1.178 | -0.690 | | Pearson Corr | - | 0.748 | 0.324 | 0.794 | 0.843 | | Weighted Tau | - | 0.721 | 0.127 | 0.697 | 0.810 | #### Model Ranking Benchmark on CIFAR10 | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | 
GoogleNet | 96.2 | 5.911 | -1.385 | 0.293 | -1.139 | | Inception V3 | 97.5 | 6.363 | -1.259 | 0.349 | -1.060 | | ResNet50 | 96.8 | 6.567 | -1.010 | 0.388 | -1.007 | | ResNet101 | 97.7 | 6.901 | -0.829 | 0.463 | -0.838 | | ResNet152 | 97.9 | 6.945 | -0.838 | 0.469 | -0.851 | | DenseNet121 | 97.2 | 6.210 | -1.035 | 0.302 | -1.006 | | DenseNet169 | 97.4 | 6.547 | -0.934 | 0.343 | -0.946 | | Densenet201 | 97.4 | 6.706 | -0.888 | 0.369 | -0.866 | | MobileNet V2 | 95.7 | 5.928 | -1.100 | 0.291 | -1.089 | | MNasNet | 96.8 | 6.018 | -1.066 | 0.304 | -1.086 | | Pearson Corr | - | 0.839 | 0.604 | 0.733 | 0.786 | | Weighted Tau | - | 0.800 | 0.638 | 0.785 | 0.714 | #### Model Ranking Benchmark on CIFAR100 | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 83.2 | 29.33 | -3.234 | 1.037 | -2.751 | | Inception V3 | 86.6 | 36.47 | -2.995 | 1.070 | -2.615 | | ResNet50 | 84.5 | 40.20 | -2.612 | 1.099 | -2.516 | | ResNet101 | 87.0 | 43.80 | -2.365 | 1.130 | -2.285 | | ResNet152 | 87.6 | 44.19 | -2.410 | 1.133 | -2.369 | | DenseNet121 | 84.8 | 32.13 | -2.665 | 1.029 | -2.504 | | DenseNet169 | 85.0 | 37.51 | -2.494 | 1.051 | -2.418 | | Densenet201 | 86.0 | 39.75 | -2.470 | 1.061 | -2.305 | | MobileNet V2 | 80.8 | 30.36 | -2.800 | 1.039 | -2.653 | | MNasNet | 83.9 | 32.05 | -2.732 | 1.051 | -2.643 | | Pearson Corr | - | 0.815 | 0.513 | 0.698 | 0.705 | | Weighted Tau | - | 0.775 | 0.659 | 0.790 | 0.654 | #### Model Ranking Benchmark on DTD | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|-------| | GoogleNet | 73.6 | 34.61 | -2.333 | 0.682 | 0.682 | | Inception V3 | 77.2 | 57.17 | -2.135 | 0.691 | 0.691 | | ResNet50 | 75.2 | 78.26 | -1.985 | 0.695 | 0.695 | | ResNet101 | 76.2 | 117.23 | -1.974 | 0.689 | 0.689 | | ResNet152 | 75.4 | 32.30 | -1.924 | 0.698 | 0.698 | | DenseNet121 | 74.9 | 35.23 | -2.001 | 0.670 | 0.670 | | DenseNet169 | 
74.8 | 43.36 | -1.817 | 0.686 | 0.686 | | Densenet201 | 74.5 | 45.96 | -1.926 | 0.689 | 0.689 | | MobileNet V2 | 72.9 | 37.99 | -2.098 | 0.664 | 0.664 | | MNasNet | 72.8 | 38.03 | -2.033 | 0.679 | 0.679 | | Pearson Corr | - | 0.532 | 0.217 | 0.617 | 0.471 | | Weighted Tau | - | 0.416 | -0.004 | 0.550 | 0.083 | #### Model Ranking Benchmark on OxfordIIITPets | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 91.9 | 28.02 | -1.064 | 0.854 | -0.815 | | Inception V3 | 93.5 | 33.29 | -0.888 | 1.119 | -0.711 | | ResNet50 | 92.5 | 32.55 | -0.805 | 0.952 | -0.721 | | ResNet101 | 94.0 | 32.76 | -0.769 | 0.985 | -0.717 | | ResNet152 | 94.5 | 32.86 | -0.732 | 1.009 | -0.679 | | DenseNet121 | 92.9 | 27.09 | -0.837 | 0.797 | -0.753 | | DenseNet169 | 93.1 | 30.09 | -0.779 | 0.829 | -0.699 | | Densenet201 | 92.8 | 31.25 | -0.810 | 0.860 | -0.716 | | MobileNet V2 | 90.5 | 27.83 | -0.902 | 0.765 | -0.822 | | MNasNet | 89.4 | 27.95 | -0.854 | 0.785 | -0.812 | | Pearson Corr | - | 0.427 | -0.127 | 0.589 | 0.501 | | Weighted Tau | - | 0.425 | -0.143 | 0.502 | 0.119 | #### Model Ranking Benchmark on StanfordCars | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 91.0 | 41.47 | -4.612 | 1.246 | -4.312 | | Inception V3 | 92.3 | 73.68 | -4.268 | 1.259 | -4.110 | | ResNet50 | 91.7 | 72.94 | -4.366 | 1.253 | -4.221 | | ResNet101 | 91.7 | 73.98 | -4.281 | 1.255 | -4.218 | | ResNet152 | 92.0 | 76.17 | -4.215 | 1.260 | -4.142 | | DenseNet121 | 91.5 | 45.82 | -4.437 | 1.249 | -4.271 | | DenseNet169 | 91.5 | 63.40 | -4.286 | 1.252 | -4.175 | | Densenet201 | 91.0 | 70.50 | -4.319 | 1.251 | -4.151 | | MobileNet V2 | 91.0 | 51.12 | -4.463 | 1.250 | -4.306 | | MNasNet | 88.5 | 51.91 | -4.423 | 1.254 | -4.338 | | Pearson Corr | - | 0.503 | 0.433 | 0.274 | 0.695 | | Weighted Tau | - | 0.638 | 0.703 | 0.654 | 0.750 | #### Model 
Ranking Benchmark on SUN397 | Model | Finetuned Acc | HScore | LEEP | LogME | NCE | |--------------|---------------|--------|--------|-------|--------| | GoogleNet | 62.0 | 71.35 | -3.744 | 1.621 | -3.055 | | Inception V3 | 65.7 | 114.21 | -3.372 | 1.648 | -2.844 | | ResNet50 | 64.7 | 110.39 | -3.198 | 1.638 | -2.894 | | ResNet101 | 64.8 | 113.63 | -3.103 | 1.642 | -2.837 | | ResNet152 | 66.0 | 116.51 | -3.056 | 1.646 | -2.822 | | DenseNet121 | 62.3 | 72.16 | -3.311 | 1.614 | -2.945 | | DenseNet169 | 63.0 | 95.80 | -3.165 | 1.623 | -2.903 | | Densenet201 | 64.7 | 103.09 | -3.205 | 1.624 | -2.896 | | MobileNet V2 | 60.5 | 75.90 | -3.338 | 1.617 | -2.968 | | MNasNet | 60.7 | 80.91 | -3.234 | 1.625 | -2.933 | | Pearson Corr | - | 0.913 | 0.428 | 0.824 | 0.782 | | Weighted Tau | - | 0.918 | 0.581 | 0.748 | 0.873 | ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{bao_information-theoretic_2019, title = {An Information-Theoretic Approach to Transferability in Task Transfer Learning}, booktitle = {ICIP}, author = {Bao, Yajie and Li, Yang and Huang, Shao-Lun and Zhang, Lin and Zheng, Lizhong and Zamir, Amir and Guibas, Leonidas}, year = {2019} } @inproceedings{nguyen_leep:_2020, title = {LEEP: A New Measure to Evaluate Transferability of Learned Representations}, booktitle = {ICML}, author = {Nguyen, Cuong and Hassner, Tal and Seeger, Matthias and Archambeau, Cedric}, year = {2020} } @inproceedings{you_logme:_2021, title = {LogME: Practical Assessment of Pre-trained Models for Transfer Learning}, booktitle = {ICML}, author = {You, Kaichao and Liu, Yong and Wang, Jianmin and Long, Mingsheng}, year = {2021} } @inproceedings{tran_transferability_2019, title = {Transferability and hardness of supervised classification tasks}, booktitle = {ICCV}, author = {Tran, Anh T. and Nguyen, Cuong V. 
and Hassner, Tal}, year = {2019} } ``` ================================================ FILE: examples/model_selection/hscore.py ================================================ """ @author: Yong Liu @contact: liuyong1095556447@163.com """ import os import sys import argparse import numpy as np import torch from torch.utils.data import DataLoader sys.path.append('../..') from tllib.ranking import h_score sys.path.append('.') import utils device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args): logger = utils.Logger(args.data, args.arch, 'results_hscore') print(args) print(f'Calc Transferabilities of {args.arch} on {args.data}') try: features = np.load(os.path.join(logger.get_save_dir(), 'features.npy')) predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy')) targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy')) print('Loaded extracted features') except: print('Conducting feature extraction') data_transform = utils.get_transform(resizing=args.resizing) print("data_transform: ", data_transform) model = utils.get_model(args.arch, args.pretrained).to(device) score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate, args.num_samples_per_classes) score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) print(f'Using {len(score_dataset)} samples for ranking') features, predictions, targets = utils.forwarding_dataset(score_loader, model, layer=eval(f'model.{args.layer}'), device=device) if args.save_features: np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features) np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions) np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets) print('Conducting transferability calculation') result = h_score(features, targets) logger.write( f'# {result:.4f} # 
data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n') print(f'Results saved in {logger.get_result_dir()}') logger.close() if __name__ == '__main__': parser = argparse.ArgumentParser(description='Ranking pre-trained models with HScore') # dataset parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--resizing', default='res.', type=str) # model parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='model to be ranked: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('-l', '--layer', default='fc', help='before which layer features are extracted') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument("--save_features", action='store_true', help="whether to save extracted features") args = parser.parse_args() main(args) ================================================ FILE: examples/model_selection/hscore.sh ================================================ #!/usr/bin/env bash # Ranking Pre-trained Model # ====================================================================================================================== # CIFAR10 CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # CIFAR100 CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py 
./data/cifar100 -d CIFAR100 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# FGVCAircraft
# Each command ranks one ImageNet-pretrained backbone; -l names the layer before which features are extracted.
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Caltech101
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# Oxford-IIIT
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features
CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python hscore.py ./data/SUN397 -d SUN397 -a mnasnet1_0 -l classifier[-1] --save_features ================================================ FILE: examples/model_selection/leep.py ================================================ """ @author: Yong Liu @contact: liuyong1095556447@163.com """ import os import sys import argparse import numpy as np import torch from torch.utils.data import DataLoader sys.path.append('../..') from tllib.ranking import log_expected_empirical_prediction as leep sys.path.append('.') import utils device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args): logger = utils.Logger(args.data, args.arch, 'results_leep') print(args) print(f'Calc Transferabilities of {args.arch} on {args.data}') try: features = np.load(os.path.join(logger.get_save_dir(), 'features.npy')) predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy')) targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy')) print('Loaded extracted features') except: print('Conducting feature extraction') data_transform = utils.get_transform(resizing=args.resizing) print("data_transform: ", data_transform) model = utils.get_model(args.arch, args.pretrained).to(device) score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform, args.sample_rate, args.num_samples_per_classes) score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, pin_memory=True) print(f'Using {len(score_dataset)} samples for ranking') 
features, predictions, targets = utils.forwarding_dataset(score_loader, model, layer=eval(f'model.{args.layer}'), device=device) if args.save_features: np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features) np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions) np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets) print('Conducting transferability calculation') result = leep(predictions, targets) logger.write( f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n') logger.close() if __name__ == '__main__': parser = argparse.ArgumentParser( description='Ranking pre-trained models with LEEP (Log Expected Empirical Prediction)') # dataset parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--resizing', default='res.', type=str) # model parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='model to be ranked: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('-l', '--layer', default='fc', help='before which layer features are extracted') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument("--save_features", action='store_true', help="whether to save extracted features") args = parser.parse_args() main(args) ================================================ FILE: examples/model_selection/leep.sh ================================================ #!/usr/bin/env bash # Ranking Pre-trained Model # ====================================================================================================================== # CIFAR10 CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar10 -d CIFAR10 -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # CIFAR100 CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a 
resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/cifar100 -d CIFAR100 -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # FGVCAircraft CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/FGVCAircraft -d Aircraft -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python 
leep.py ./data/FGVCAircraft -d Aircraft -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # Caltech101 CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/caltech101 -d Caltech101 -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py 
./data/dtd -d DTD -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/dtd -d DTD -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # Oxford-IIIT CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/Oxford-IIIT -d OxfordIIITPets -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars 
-a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a densenet201 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a mobilenet_v2 -l classifier[-1] --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/stanford_cars -d StanfordCars -a mnasnet1_0 -l classifier[-1] --save_features # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet50 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet101 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a resnet152 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a googlenet -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a inception_v3 --resizing res.299 -l fc --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a densenet121 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py ./data/SUN397 -d SUN397 -a densenet169 -l classifier --save_features CUDA_VISIBLE_DEVICES=0 python leep.py 
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader

sys.path.append('../..')
from tllib.ranking import log_maximum_evidence as logme

sys.path.append('.')
import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    """Rank the pre-trained model ``args.arch`` on dataset ``args.data`` with LogME.

    Features/predictions/targets are cached as ``.npy`` files next to the log;
    when the cache is missing or unreadable they are re-extracted by a single
    forward pass over the scoring split, then the LogME score is appended to
    ``results_logme.txt``.
    """
    logger = utils.Logger(args.data, args.arch, 'results_logme')
    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')

    try:
        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))
        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))
        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))
        print('Loaded extracted features')
    # Only fall back to re-extraction when the cache is missing or corrupt.
    # The previous bare `except:` swallowed every error, including
    # KeyboardInterrupt and genuine bugs, silently triggering re-extraction.
    except (OSError, ValueError):
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)

        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform,
                                                       args.sample_rate, args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')

        # NOTE(review): eval() is kept because --layer may contain a subscript
        # (e.g. "classifier[-1]") that attrgetter cannot express. The value
        # comes from the command line, so never expose it to untrusted input.
        features, predictions, targets = utils.forwarding_dataset(
            score_loader, model, layer=eval(f'model.{args.layer}'), device=device)

        if args.save_features:
            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)
            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)
            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)

    print('Conducting transferability calculation')
    result = logme(features, targets)

    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    logger.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ranking pre-trained models with LogME (Log Maximum Evidence)')
    # dataset
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--resizing', default='res.', type=str)
    # model
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='model to be ranked: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('-l', '--layer', default='fc',
                        help='before which layer features are extracted')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument("--save_features", action='store_true',
                        help="whether to save extracted features")
    args = parser.parse_args()
    main(args)
#!/usr/bin/env bash
# Ranking Pre-trained Models with LogME.
#
# Emits exactly the same command sequence as before, but generated from two
# tables instead of being written out once per (dataset, architecture) pair.

# dataset root paired with its dataset name, in the original run order
datasets=(
  "./data/cifar10 CIFAR10"
  "./data/cifar100 CIFAR100"
  "./data/FGVCAircraft Aircraft"
  "./data/caltech101 Caltech101"
  "./data/dtd DTD"
  "./data/Oxford-IIIT OxfordIIITPets"
  "./data/stanford_cars StanfordCars"
  "./data/SUN397 SUN397"
)

# architecture, the layer before which features are extracted, and extra flags
archs=(
  "resnet50 fc"
  "resnet101 fc"
  "resnet152 fc"
  "googlenet fc"
  "inception_v3 fc --resizing res.299"
  "densenet121 classifier"
  "densenet169 classifier"
  "densenet201 classifier"
  "mobilenet_v2 classifier[-1]"
  "mnasnet1_0 classifier[-1]"
)

for ds in "${datasets[@]}"; do
  read -r root name <<< "${ds}"
  for entry in "${archs[@]}"; do
    # extra is empty for every model except inception_v3 (299x299 input)
    read -r arch layer extra <<< "${entry}"
    CUDA_VISIBLE_DEVICES=0 python logme.py "${root}" -d "${name}" -a "${arch}" ${extra} -l "${layer}" --save_features
  done
done
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import os
import sys
import argparse

import numpy as np
import torch
from torch.utils.data import DataLoader

sys.path.append('../..')
from tllib.ranking import negative_conditional_entropy as nce

sys.path.append('.')
import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    """Rank the pre-trained model ``args.arch`` on dataset ``args.data`` with NCE.

    Features/predictions/targets are cached as ``.npy`` files next to the log;
    when the cache is missing or unreadable they are re-extracted by a single
    forward pass over the scoring split. NCE scores the hard predictions
    (argmax over classes) against the target labels, and the result is
    appended to ``results_nce.txt``.
    """
    logger = utils.Logger(args.data, args.arch, 'results_nce')
    print(args)
    print(f'Calc Transferabilities of {args.arch} on {args.data}')

    try:
        features = np.load(os.path.join(logger.get_save_dir(), 'features.npy'))
        predictions = np.load(os.path.join(logger.get_save_dir(), 'preds.npy'))
        targets = np.load(os.path.join(logger.get_save_dir(), 'targets.npy'))
        print('Loaded extracted features')
    # Only fall back to re-extraction when the cache is missing or corrupt.
    # The previous bare `except:` swallowed every error, including
    # KeyboardInterrupt and genuine bugs, silently triggering re-extraction.
    except (OSError, ValueError):
        print('Conducting feature extraction')
        data_transform = utils.get_transform(resizing=args.resizing)
        print("data_transform: ", data_transform)

        model = utils.get_model(args.arch, args.pretrained).to(device)
        score_dataset, num_classes = utils.get_dataset(args.data, args.root, data_transform,
                                                       args.sample_rate, args.num_samples_per_classes)
        score_loader = DataLoader(score_dataset, batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.workers, pin_memory=True)
        print(f'Using {len(score_dataset)} samples for ranking')

        # NOTE(review): eval() is kept because --layer may contain a subscript
        # (e.g. "classifier[-1]") that attrgetter cannot express. The value
        # comes from the command line, so never expose it to untrusted input.
        features, predictions, targets = utils.forwarding_dataset(
            score_loader, model, layer=eval(f'model.{args.layer}'), device=device)

        if args.save_features:
            np.save(os.path.join(logger.get_save_dir(), 'features.npy'), features)
            np.save(os.path.join(logger.get_save_dir(), 'preds.npy'), predictions)
            np.save(os.path.join(logger.get_save_dir(), 'targets.npy'), targets)

    print('Conducting transferability calculation')
    result = nce(np.argmax(predictions, axis=1), targets)

    logger.write(
        f'# {result:.4f} # data_{args.data}_sr{args.sample_rate}_sc{args.num_samples_per_classes}_model_{args.arch}_layer_{args.layer}\n')
    logger.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ranking pre-trained models with NCE (Negative Conditional Entropy)')
    # dataset
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N',
                        help='sample rate of training dataset (default: 100)')
    parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int,
                        help='number of samples per classes.')
    parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--resizing', default='res.', type=str)
    # model
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='model to be ranked: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('-l', '--layer', default='fc',
                        help='before which layer features are extracted')
    parser.add_argument('--pretrained', default=None,
                        help="pretrained checkpoint of the backbone. "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument("--save_features", action='store_true',
                        help="whether to save extracted features")
    args = parser.parse_args()
    main(args)
#!/usr/bin/env bash
# Ranking Pre-trained Models with NCE.
#
# Emits exactly the same command sequence as before, but generated from two
# tables instead of being written out once per (dataset, architecture) pair.

# dataset root paired with its dataset name, in the original run order
datasets=(
  "./data/cifar10 CIFAR10"
  "./data/cifar100 CIFAR100"
  "./data/FGVCAircraft Aircraft"
  "./data/caltech101 Caltech101"
  "./data/dtd DTD"
  "./data/Oxford-IIIT OxfordIIITPets"
  "./data/stanford_cars StanfordCars"
  "./data/SUN397 SUN397"
)

# architecture, the layer before which features are extracted, and extra flags
archs=(
  "resnet50 fc"
  "resnet101 fc"
  "resnet152 fc"
  "googlenet fc"
  "inception_v3 fc --resizing res.299"
  "densenet121 classifier"
  "densenet169 classifier"
  "densenet201 classifier"
  "mobilenet_v2 classifier[-1]"
  "mnasnet1_0 classifier[-1]"
)

for ds in "${datasets[@]}"; do
  read -r root name <<< "${ds}"
  for entry in "${archs[@]}"; do
    # extra is empty for every model except inception_v3 (299x299 input)
    read -r arch layer extra <<< "${entry}"
    CUDA_VISIBLE_DEVICES=0 python nce.py "${root}" -d "${name}" -a "${arch}" ${extra} -l "${layer}" --save_features
  done
done
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import random
import sys, os

import torch
import timm
from torch.utils.data import Subset
import torchvision.transforms as T
import torch.nn.functional as F
import torchvision.models as models

sys.path.append('../../..')
import tllib.vision.datasets as datasets


class Logger(object):
    """Tees stream output to an external text file.

    Args:
        data_name (str): dataset name; joined with ``model_name`` to form the
            directory where intermediate features/outputs are saved
        model_name (str): model name
        metric_name (str): metric name; ranking results are appended to
            ``<data_name>/<metric_name>.txt``
        stream: the stream to mirror output to. Default: sys.stdout
    """

    def __init__(self, data_name, model_name, metric_name, stream=sys.stdout):
        self.terminal = stream
        self.save_dir = os.path.join(data_name, model_name)  # save intermediate features/outputs
        self.result_dir = os.path.join(data_name, f'{metric_name}.txt')  # save ranking results
        os.makedirs(self.save_dir, exist_ok=True)
        self.log = open(self.result_dir, 'a')

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def get_save_dir(self):
        return self.save_dir

    def get_result_dir(self):
        return self.result_dir

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        # FIX: only close the log file. The terminal stream is usually
        # sys.stdout, which is shared with the rest of the program; closing it
        # would break every subsequent print().
        self.log.close()


def get_model_names():
    """Return sorted torchvision model names plus all timm model names."""
    return sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    ) + timm.list_models()


def forwarding_dataset(score_loader, model, layer, device):
    """Run a single forward pass over the full dataset.

    Args:
        score_loader: the dataloader for scoring transferability
        model: the model for scoring transferability
        layer: the layer before which features are extracted (a forward hook
            is registered on it)
        device: device to run the forward pass on

    Returns:
        features: inputs fed into ``layer`` (numpy array)
        predictions: softmax probabilities of ``layer``'s outputs (numpy array)
        targets: ground-truth labels of the dataset (numpy array)
    """
    features = []
    outputs = []
    targets = []

    def hook_fn_forward(module, input, output):
        # ``input`` is a tuple of the layer's positional inputs
        features.append(input[0].detach().cpu())
        outputs.append(output.detach().cpu())

    forward_hook = layer.register_forward_hook(hook_fn_forward)

    model.eval()
    with torch.no_grad():
        for data, target in score_loader:
            targets.append(target)
            data = data.to(device)
            _ = model(data)
    forward_hook.remove()

    features = torch.cat(features).numpy()
    outputs = torch.cat(outputs)
    predictions = F.softmax(outputs, dim=-1).numpy()
    targets = torch.cat(targets).numpy()

    return features, predictions, targets


def get_model(model_name, pretrained=True, pretrained_checkpoint=None):
    """Build a backbone by name, optionally loading a checkpoint.

    Models are looked up first in torchvision, then in pytorch-image-models.
    """
    if model_name in get_model_names():
        # load models from torchvision
        backbone = models.__dict__[model_name](pretrained=pretrained)
    else:
        # load models from pytorch-image-models
        backbone = timm.create_model(model_name, pretrained=pretrained)
    if pretrained_checkpoint:
        print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint))
        # map to CPU so checkpoints saved on GPU load on CPU-only machines
        pretrained_dict = torch.load(pretrained_checkpoint, map_location='cpu')
        backbone.load_state_dict(pretrained_dict, strict=False)
    return backbone


def get_dataset(dataset_name, root, transform, sample_rate=100, num_samples_per_classes=None, split='train'):
    """
    When sample_rate < 100, e.g. sample_rate = 50, use 50% data to train the model.
    Otherwise, if num_samples_per_classes is not None, e.g. 5, then sample 5 images
    for each class, and use them to train the model;
    otherwise, keep all the data.
    """
    dataset = datasets.__dict__[dataset_name]
    if sample_rate < 100:
        score_dataset = dataset(root=root, split=split, sample_rate=sample_rate, download=True, transform=transform)
    else:
        score_dataset = dataset(root=root, split=split, download=True, transform=transform)
    num_classes = len(score_dataset.classes)
    if num_samples_per_classes is not None:
        samples = list(range(len(score_dataset)))
        random.shuffle(samples)
        samples_len = min(num_samples_per_classes * num_classes, len(score_dataset))
        print("Origin dataset:", len(score_dataset), "Sampled dataset:", samples_len, "Ratio:",
              float(samples_len) / len(score_dataset))
        # FIX: the Subset was previously assigned to a discarded local variable
        # and the full dataset was returned, so sampling silently never applied.
        score_dataset = Subset(score_dataset, samples[:samples_len])
    return score_dataset, num_classes


def get_transform(resizing='res.'):
    """
    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        - res.: resize the image to (224, 224)
        - res.299: resize the image to (299, 299), e.g. for inception_v3
        - res.|crop: resize the image to (256, 256) and then take a central
          crop of size 224.
    """
    if resizing == 'default':
        transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(224),
        ])
    elif resizing == 'res.':
        transform = T.Resize((224, 224))
    elif resizing == 'res.299':
        transform = T.Resize((299, 299))
    elif resizing == 'res.|crop':
        transform = T.Compose([
            T.Resize((256, 256)),
            T.CenterCrop(224),
        ])
    else:
        raise NotImplementedError(resizing)
    return T.Compose([
        transform,
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
    ])
You also need to install timm to use PyTorch-Image-Models. ``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [FOOD-101](https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101/) - [CIFAR10](http://www.cs.utoronto.ca/~kriz/cifar.html) - [CIFAR100](http://www.cs.utoronto.ca/~kriz/cifar.html) - [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) - [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) - [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) - [SUN397](https://vision.princeton.edu/projects/2010/SUN/) - [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html) - [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/) - [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) - [Caltech101](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) ## Supported Methods Supported methods include: - [Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (Pseudo Label, ICML 2013)](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.664.3543&rep=rep1&type=pdf) - [Temporal Ensembling for Semi-Supervised Learning (Pi Model, ICLR 2017)](https://arxiv.org/abs/1610.02242) - [Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results (Mean Teacher, NIPS 2017)](https://arxiv.org/abs/1703.01780) - [Self-Training With Noisy Student Improves ImageNet Classification (Noisy Student, CVPR 2020)](https://openaccess.thecvf.com/content_CVPR_2020/papers/Xie_Self-Training_With_Noisy_Student_Improves_ImageNet_Classification_CVPR_2020_paper.pdf) - [Unsupervised Data Augmentation for Consistency Training (UDA, NIPS 2020)](https://arxiv.org/pdf/1904.12848v4.pdf) - [FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence (FixMatch, NIPS 2020)](https://arxiv.org/abs/2001.07685) - [Self-Tuning for Data-Efficient Deep Learning (Self-Tuning, ICML 
2021)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-for-Data-Efficient-Deep-Learning-icml21.pdf) - [FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling (FlexMatch, NIPS 2021)](https://arxiv.org/abs/2110.08263) - [Debiased Learning From Naturally Imbalanced Pseudo-Labels (DebiasMatch, CVPR 2022)](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Debiased_Learning_From_Naturally_Imbalanced_Pseudo-Labels_CVPR_2022_paper.pdf) - [Debiased Self-Training for Semi-Supervised Learning (DST)](https://arxiv.org/abs/2202.07136) ## Usage ### Semi-supervised learning with supervised pre-trained model The shell files give the script to train with supervised pre-trained model with specified hyper-parameters. For example, if you want to train UDA on CIFAR100, use the following script ```shell script # Semi-supervised learning on CIFAR100 (ResNet50, 400labels). # Assume you have put the datasets under the path `data/cifar100`, # or you are glad to download the datasets automatically from the Internet to this path CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class ``` Following common practice in semi-supervised learning, we select a class-balanced subset as the labeled dataset and treat other samples as unlabeled data. In the above command, `num-samples-per-class` specifies how many labeled samples for each class. Note that the labeled subset is **deterministic with the same random seed**. Hence, if you want to compare different algorithms with the same labeled subset, you can simply pass in the same random seed. ### Semi-supervised learning with unsupervised pre-trained model Take MoCo as an example. 1. 
Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco 2. Convert the format of the MoCo checkpoints to the standard format of pytorch ```shell mkdir checkpoints python convert_moco_to_pretrained.py checkpoints/moco_v2_800ep_pretrain.pth.tar checkpoints/moco_v2_800ep_backbone.pth checkpoints/moco_v2_800ep_fc.pth ``` 3. Start training ```shell CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/erm_moco_pretrain/cifar100_4_labels_per_class ``` ## Experiment and Results **Notations** - ``Avg`` is the accuracy reported by `TLlib`. - ``ERM`` refers to the model trained with only labeled data. - ``Oracle`` refers to the model trained using all data as labeled data. Below are the results of implemented methods. Other than _Oracle_, we randomly sample 4 labels per category. 
### ImageNet Supervised Pre-training (ResNet-50) | Methods | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD | Pets | Flowers | Caltech | Avg | |--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------| | ERM | 33.6 | 59.4 | 47.9 | 48.6 | 29.0 | 37.1 | 40.9 | 50.5 | 82.2 | 87.6 | 82.2 | 54.5 | | Pseudo Label | 36.9 | 62.8 | 52.5 | 54.9 | 30.4 | 40.4 | 41.7 | 54.1 | 89.6 | 93.5 | 85.1 | 58.4 | | Pi Model | 34.2 | 66.9 | 48.5 | 47.9 | 26.7 | 37.4 | 40.9 | 51.9 | 83.5 | 92.0 | 82.2 | 55.6 | | Mean Teacher | 40.4 | 78.1 | 58.5 | 52.8 | 32.0 | 45.6 | 40.2 | 53.8 | 86.8 | 92.8 | 83.7 | 60.4 | | UDA | 41.9 | 73.0 | 59.8 | 55.4 | 33.5 | 42.7 | 42.1 | 49.7 | 88.0 | 93.4 | 85.3 | 60.4 | | FixMatch | 36.2 | 74.5 | 58.0 | 52.6 | 27.1 | 44.8 | 40.8 | 50.2 | 87.8 | 93.6 | 83.2 | 59.0 | | Self Tuning | 41.4 | 70.9 | 57.2 | 60.5 | 37.0 | 59.8 | 43.5 | 51.7 | 88.4 | 93.5 | 89.1 | 63.0 | | FlexMatch | 48.1 | 94.2 | 69.2 | 65.1 | 38.0 | 55.3 | 50.2 | 55.6 | 91.5 | 94.6 | 89.4 | 68.3 | | DebiasMatch | 57.1 | 92.4 | 69.0 | 66.2 | 41.5 | 65.4 | 48.3 | 54.2 | 90.2 | 95.4 | 89.3 | 69.9 | | DST | 58.1 | 93.5 | 67.8 | 68.6 | 44.9 | 68.6 | 47.0 | 56.3 | 91.5 | 95.1 | 90.3 | 71.1 | | Oracle | 85.5 | 97.5 | 86.3 | 81.1 | 85.1 | 91.1 | 64.1 | 68.8 | 93.2 | 98.1 | 92.6 | 85.8 | ### ImageNet Unsupervised Pre-training (ResNet-50, MoCo v2) | Methods | Food101 | CIFAR10 | CIFAR100 | CUB200 | Aircraft | Cars | SUN397 | DTD | Pets | Flowers | Caltech | Avg | |--------------|---------|---------|----------|--------|----------|------|--------|------|------|---------|---------|------| | ERM | 33.5 | 63.0 | 50.8 | 39.4 | 28.1 | 40.3 | 40.7 | 53.7 | 65.4 | 87.5 | 82.8 | 53.2 | | Pseudo Label | 33.6 | 71.9 | 53.8 | 42.7 | 30.9 | 51.2 | 41.2 | 55.2 | 69.3 | 94.2 | 86.2 | 57.3 | | Pi Model | 32.7 | 77.9 | 50.9 | 33.6 | 27.2 | 34.4 | 41.1 | 54.9 | 66.7 | 91.4 | 84.1 | 54.1 | | Mean Teacher | 36.8 | 79.0 | 56.7 | 43.0 | 33.0 
| 53.9 | 39.5 | 54.5 | 67.8 | 92.7 | 83.3 | 58.2 | | UDA | 39.5 | 91.3 | 60.0 | 41.9 | 36.2 | 39.7 | 41.7 | 51.5 | 71.0 | 93.7 | 86.5 | 59.4 | | FixMatch | 44.3 | 86.1 | 58.0 | 42.7 | 38.0 | 55.4 | 42.4 | 53.1 | 67.9 | 95.2 | 83.4 | 60.6 | | Self Tuning | 34.0 | 63.6 | 51.7 | 43.3 | 32.2 | 50.2 | 40.7 | 52.7 | 68.2 | 91.8 | 87.7 | 56.0 | | FlexMatch | 50.2 | 96.6 | 69.2 | 49.4 | 41.3 | 62.5 | 47.2 | 54.5 | 72.4 | 94.8 | 89.4 | 66.1 | | DebiasMatch | 54.2 | 95.5 | 68.1 | 49.1 | 40.9 | 73.0 | 47.6 | 54.4 | 76.6 | 95.5 | 88.7 | 67.6 | | DST | 57.1 | 95.0 | 68.2 | 53.6 | 47.7 | 72.0 | 46.8 | 56.0 | 76.3 | 95.6 | 90.1 | 68.9 | | Oracle | 87.0 | 98.2 | 87.9 | 80.6 | 88.7 | 92.7 | 63.9 | 73.8 | 90.6 | 97.8 | 93.1 | 86.8 | ## TODO 1. support multi-gpu training 2. add training from scratch code and results ## Citation If you use these methods in your research, please consider citing. ``` @inproceedings{pseudo_label, title={Pseudo-label: The simple and efficient semi-supervised learning method for deep neural networks}, author={Lee, Dong-Hyun and others}, booktitle={ICML}, year={2013} } @inproceedings{pi_model, title={Temporal ensembling for semi-supervised learning}, author={Laine, Samuli and Aila, Timo}, booktitle={ICLR}, year={2017} } @inproceedings{mean_teacher, title={Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results}, author={Tarvainen, Antti and Valpola, Harri}, booktitle={NIPS}, year={2017} } @inproceedings{noisy_student, title={Self-training with noisy student improves imagenet classification}, author={Xie, Qizhe and Luong, Minh-Thang and Hovy, Eduard and Le, Quoc V}, booktitle={CVPR}, year={2020} } @inproceedings{UDA, title={Unsupervised data augmentation for consistency training}, author={Xie, Qizhe and Dai, Zihang and Hovy, Eduard and Luong, Thang and Le, Quoc}, booktitle={NIPS}, year={2020} } @inproceedings{FixMatch, title={Fixmatch: Simplifying semi-supervised learning with consistency and 
confidence}, author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang}, booktitle={NIPS}, year={2020} } @inproceedings{SelfTuning, title={Self-tuning for data-efficient deep learning}, author={Wang, Ximei and Gao, Jinghan and Long, Mingsheng and Wang, Jianmin}, booktitle={ICML}, year={2021} } @inproceedings{FlexMatch, title={Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling}, author={Zhang, Bowen and Wang, Yidong and Hou, Wenxin and Wu, Hao and Wang, Jindong and Okumura, Manabu and Shinozaki, Takahiro}, booktitle={NeurIPS}, year={2021} } @inproceedings{DebiasMatch, title={Debiased Learning from Naturally Imbalanced Pseudo-Labels}, author={Wang, Xudong and Wu, Zhirong and Lian, Long and Yu, Stella X}, booktitle={CVPR}, year={2022} } @article{DST, title={Debiased Self-Training for Semi-Supervised Learning}, author={Chen, Baixu and Jiang, Junguang and Wang, Ximei and Wang, Jianmin and Long, Mingsheng}, journal={arXiv preprint arXiv:2202.07136}, year={2022} } ``` ================================================ FILE: examples/semi_supervised_learning/image_classification/convert_moco_to_pretrained.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import sys import torch if __name__ == "__main__": input = sys.argv[1] obj = torch.load(input, map_location="cpu") obj = obj["state_dict"] newmodel = {} fc = {} for k, v in obj.items(): if not k.startswith("module.encoder_q."): continue old_k = k k = k.replace("module.encoder_q.", "") if k.startswith("fc"): print(k) fc[k] = v else: newmodel[k] = v with open(sys.argv[2], "wb") as f: torch.save(newmodel, f) with open(sys.argv[3], "wb") as f: torch.save(fc, f) ================================================ FILE: examples/semi_supervised_learning/image_classification/debiasmatch.py 
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Entry point: build data loaders, model and optimizer, then train/test DebiasMatch."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: both labeled and unlabeled samples yield a
    # (weakly augmented, strongly augmented) pair of views.
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform,
                          val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer,
                                lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9,
                        weight_decay=args.wd, nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # initialize q_hat (uniform). train() refines it *in place* as a moving
    # average of the model's marginal predictions on unlabeled data, so the
    # estimate keeps accumulating across epochs.
    q_hat = (torch.ones(num_classes) / num_classes).to(device)

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, q_hat, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))

    logger.close()


def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model,
          optimizer: SGD, lr_scheduler: LambdaLR, q_hat, epoch: int, args: argparse.Namespace):
    """Train DebiasMatch for one epoch.

    ``q_hat`` (moving-average estimate of the marginal prediction distribution
    on unlabeled data) is updated **in place**, so the caller's tensor keeps
    accumulating across epochs.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # cross entropy loss on labeled data (weak + strong views)
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + \
            args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss.backward()

        # self training loss: predictions on the weak view serve as targets and
        # must not propagate gradients; the strong view keeps its graph so that
        # self_training_loss.backward() below can run
        with torch.no_grad():
            y_u = model(x_u)
        y_u_strong = model(x_u_strong)

        # update q_hat.
        # FIX: update in place (mul_/add_) instead of rebinding the local name
        # `q_hat` — a plain reassignment only updates a local reference and the
        # moving average would silently reset to uniform at every epoch.
        q = torch.softmax(y_u, dim=1).mean(dim=0)
        q_hat.mul_(args.momentum).add_((1 - args.momentum) * q)

        # debiased pseudo labeling: logits are adjusted by +/- tau * log(q_hat)
        self_training_loss, mask, pseudo_labels = self_training_criterion(
            y_u_strong + args.tau * torch.log(q_hat),
            y_u - args.tau * torch.log(q_hat))
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of pseudo labels above the confidence threshold
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # accuracy of pseudo labels (ground-truth labels of "unlabeled" data
        # are available here only for monitoring)
        if n_pseudo_labels > 0:
            # unconfident samples are mapped to -1, which never matches a label
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DebiasMatch for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--momentum', default=0.999, type=float,
                        help='momentum coefficient for updating q_hat (default: 0.999)')
    parser.add_argument('--tau', default=1, type=float,
                        help='debiased strength (default: 1)')
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='confidence threshold')
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', dest='lr', help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float,
                        help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float,
                        metavar='W', help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run (default: 90)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training ')
    parser.add_argument("--log", default='debiasmatch', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
--norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.9 --tau 1 --seed 0 --log logs/debiasmatch/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.7 --tau 1 --seed 0 --log logs/debiasmatch/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 7 --tau 3 --seed 0 --log logs/debiasmatch/dtd_4_labels_per_class # 
====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch/caltech_4_labels_per_class # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 
--finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log logs/debiasmatch_moco_pretrain/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --tau 1 --seed 0 --log 
logs/debiasmatch_moco_pretrain/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python debiasmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --tau 3 --seed 0 --log logs/debiasmatch_moco_pretrain/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python debiasmatch.py 
def main(args: argparse.Namespace):
    """Entry point for Debiased Self-Training (DST) on a semi-supervised image-classification task.

    Builds (weak, strong) augmentation pairs for both the labeled and unlabeled
    splits, trains the DST ``ImageClassifier`` for ``args.epochs`` epochs, and
    tracks the best top-1 accuracy. With ``--phase test`` it only evaluates the
    best saved checkpoint.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # Optional determinism: seed Python/torch RNGs and force deterministic cuDNN
    # kernels (slower, and still combined with cudnn.benchmark below, as upstream does).
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code: every train sample (labeled and unlabeled) is returned as a
    # (weakly augmented, strongly augmented) pair via MultipleApply.
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform,
                          val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    # drop_last=True keeps every training batch at full batch_size, which train()
    # relies on when splitting the concatenated labeled/unlabeled forward pass in half.
    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size,
                                        shuffle=True, num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model: DST classifier head(s) on top of a (possibly externally pretrained) backbone
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, width=args.width,
                                 pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler: exponential decay ('exp') or cosine with warmup
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # test phase: resume from the best checkpoint, evaluate once, and return
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint; 'best' is the snapshot with the
        # highest top-1 accuracy. NOTE: best_avg is tracked independently of best_acc1,
        # so it may come from a different epoch than the 'best' checkpoint.
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one DST training epoch (``args.iters_per_epoch`` iterations).

    Per iteration, three loss terms are computed and their gradients ACCUMULATED via
    three separate ``backward()`` calls before a single ``optimizer.step()``:

    1. supervised CE on strongly augmented labeled data (scaled by ``--trade-off-cls-strong``);
    2. supervised CE on weakly augmented labeled data + the worst-case estimation
       (adversarial head) loss, from one forward pass over ``cat(x_l, x_u)``;
    3. confidence-thresholded self-training loss on strongly augmented unlabeled data,
       with pseudo labels taken from the weak-augmentation predictions ``y_u``.

    Splitting the forward passes (rather than summing all losses first) keeps peak
    memory lower; the ordering of backward calls before the step is load-bearing.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    wce_losses = AverageMeter('Worst Case Estimation Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, wce_losses, cls_accs,
         pseudo_label_accs, pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    worst_case_estimation_criterion = WorstCaseEstimationLoss(args.eta_prime).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        # each iterator yields ((weak, strong) image pair, labels)
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        # labels_u is used only for pseudo-label accuracy diagnostics, never for the loss
        (x_u, x_u_strong), labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad once; the three backward() calls below accumulate into the same grads
        optimizer.zero_grad()

        # ==============================================================================================================
        # (1) cross entropy loss (strong augment) on labeled data
        # ==============================================================================================================
        y_l_strong, _, _ = model(x_l_strong)
        cls_loss_strong = args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        cls_loss_strong.backward()

        # single forward pass over concatenated labeled+unlabeled batch (both full
        # batch_size thanks to drop_last), split back in half afterwards
        x = torch.cat((x_l, x_u), dim=0)
        outputs, outputs_adv, _ = model(x)
        y_l, y_u = outputs.chunk(2, dim=0)
        y_l_adv, y_u_adv = outputs_adv.chunk(2, dim=0)

        # ==============================================================================================================
        # (2) cross entropy loss (weak augment) + worst case estimation loss
        # ==============================================================================================================
        cls_loss_weak = F.cross_entropy(y_l, labels_l)
        wce_loss = args.eta * worst_case_estimation_criterion(y_l, y_l_adv, y_u, y_u_adv)
        (cls_loss_weak + wce_loss).backward()

        # ==============================================================================================================
        # (3) self training loss: pseudo labels come from weak-augment predictions y_u
        # (the criterion presumably detaches y_u internally, since its graph was freed
        # by the backward above — defined in tllib.self_training.pseudo_label)
        # ==============================================================================================================
        _, _, y_u_strong = model(x_u_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u)
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss (meter bookkeeping only; no grads involved)
        cls_loss = cls_loss_strong + cls_loss_weak
        cls_losses.update(cls_loss.item(), batch_size)
        loss = cls_loss + self_training_loss + wce_loss
        losses.update(loss.item(), batch_size)
        wce_losses.update(wce_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)

        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of pseudo labels that passed the confidence threshold
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # accuracy of pseudo labels: entries with mask==0 are mapped to -1 so they can
        # never match a true class index (assumes mask is a {0,1} tensor — per the
        # confidence-based criterion above)
        if n_pseudo_labels > 0:
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step on the accumulated grads of all three losses;
        # model.step() presumably advances DST-internal state (see tllib.self_training.dst)
        optimizer.step()
        lr_scheduler.step()
        model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Debiased Self-Training for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--width', default=2048, type=int,
                        help='width of the pseudo head and the worst-case estimation head')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters (DST-specific trade-offs first, then generic optimization knobs)
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--eta', default=1, type=float, help='the trade-off hyper-parameter of adversarial loss')
    parser.add_argument('--eta-prime', default=2, type=float,
                        help="the trade-off hyper-parameter between adversarial loss on labeled data "
                             "and that on unlabeled data")
    parser.add_argument('--threshold', default=0.7, type=float, help='confidence threshold')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0002, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
                        help='number of total epochs to run (default: 90)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ')
    parser.add_argument("--log", default='dst', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)
#!/usr/bin/env bash
# Reference hyper-parameters for DST (Debiased Self-Training) across all supported
# datasets, for two backbone initializations: ImageNet supervised pretraining and
# MoCo v2 unsupervised pretraining (requires checkpoints/moco_v2_800ep_backbone.pth).

# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.8 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.95 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.9 --trade-off-self-training 0.3 --eta-prime 1 \
  --seed 0 --log logs/dst/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --trade-off-self-training 1 --eta-prime 4 \
  --seed 0 --log logs/dst/caltech_4_labels_per_class

# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python dst.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python dst.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python dst.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python dst.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python dst.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.3 --eta-prime 2 \
  --seed 0 --log logs/dst_moco_pretrain/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python dst.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 3 \
  --seed 0 --log logs/dst_moco_pretrain/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python dst.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --trade-off-self-training 0.1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python dst.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python dst.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --trade-off-self-training 0.1 --eta-prime 1 \
  --seed 0 --log logs/dst_moco_pretrain/caltech_4_labels_per_class
def main(args: argparse.Namespace):
    """Train (or evaluate) the supervised-only ERM baseline for semi-supervised classification."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    # Optional determinism: seed Python/torch RNGs and force deterministic cuDNN kernels.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading: each training image is returned as a (weak, strong) augmentation pair.
    weak_aug = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                         norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_aug = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                           auto_augment=args.auto_augment,
                                           norm_mean=args.norm_mean, norm_std=args.norm_std)
    train_transform = MultipleApply([weak_aug, strong_aug])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('train_transform: ', train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = utils.get_dataset(
        args.data, args.num_samples_per_class, args.root, train_transform, val_transform, seed=args.seed)
    if args.oracle:
        # Oracle upper bound: fold the unlabeled split into the labeled set,
        # re-attaching num_classes since ConcatDataset does not carry it over.
        n_classes = labeled_train_dataset.num_classes
        labeled_train_dataset = ConcatDataset([labeled_train_dataset, unlabeled_train_dataset])
        labeled_train_dataset.num_classes = n_classes
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))
    train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # Build the classifier on top of a (possibly externally pretrained) backbone.
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    if args.no_pool:
        pool_layer = nn.Identity()
    else:
        pool_layer = None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # Optimizer + lr schedule: exponential decay ('exp') or cosine with warmup ('cos').
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        exp_decay = lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)
        lr_scheduler = LambdaLR(optimizer, exp_decay)
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9,
                        weight_decay=args.wd, nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # Test phase: load the best checkpoint, evaluate once, and exit.
    if args.phase == 'test':
        state_dict = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(state_dict)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # Training loop: after every epoch, snapshot 'latest' and promote it to 'best'
    # whenever top-1 accuracy improves.
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        print(lr_scheduler.get_lr())

        # one epoch of plain supervised training on the labeled split
        utils.empirical_risk_minimization(train_iter, classifier, optimizer, lr_scheduler, epoch, args, device)

        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_acc1 = acc1
        if avg > best_avg:
            best_avg = avg

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()
default='default', type=str) parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+', help='normalization mean') parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+', help='normalization std') parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str, help='AutoAugment policy (default: rand-m10-n2-mstd2)') parser.add_argument('--oracle', action='store_true', default=False, help='use all data as labeled data (oracle)') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor') parser.add_argument('--pretrained-backbone', default=None, type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') # training parameters parser.add_argument('--trade-off-cls-strong', default=0.1, type=float, help='the trade-off hyper-parameter of cls loss on strong augmented labeled data') parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr', help='initial learning rate') parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy') parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') 
parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run (default: 20)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='baseline', type=str, help="where to save logs, checkpoints and debugging images") parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'], help="when phase is 'test', only test the model") args = parser.parse_args() main(args) ================================================ FILE: examples/semi_supervised_learning/image_classification/erm.sh ================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/erm/food101_4_labels_per_class CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --num-samples-per-class 10 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/erm/food101_10_labels_per_class CUDA_VISIBLE_DEVICES=0 python erm.py data/food101 -d Food101 --oracle -a resnet50 \ --lr 0.01 --finetune --epochs 80 --seed 0 --log logs/erm/food101_oracle # ====================================================================================================================== # CIFAR 10 
# ----------------------------------------------------------------------------
# Remaining benchmark runs. Every dataset is trained three times:
#   1) 4 labeled samples per class,
#   2) 10 labeled samples per class,
#   3) an oracle run on all labels (--oracle, 80 epochs).
# Usage: run_erm <log_root> <dataset> <data_root> <lr> <log_tag> [extra args...]
# ----------------------------------------------------------------------------
run_erm() {
  local log_root=$1 dataset=$2 data_root=$3 lr=$4 tag=$5
  shift 5
  CUDA_VISIBLE_DEVICES=0 python erm.py "${data_root}" -d "${dataset}" "$@" --num-samples-per-class 4 -a resnet50 \
    --lr "${lr}" --finetune --seed 0 --log "logs/${log_root}/${tag}_4_labels_per_class"
  CUDA_VISIBLE_DEVICES=0 python erm.py "${data_root}" -d "${dataset}" "$@" --num-samples-per-class 10 -a resnet50 \
    --lr "${lr}" --finetune --seed 0 --log "logs/${log_root}/${tag}_10_labels_per_class"
  CUDA_VISIBLE_DEVICES=0 python erm.py "${data_root}" -d "${dataset}" "$@" --oracle -a resnet50 \
    --lr "${lr}" --finetune --epochs 80 --seed 0 --log "logs/${log_root}/${tag}_oracle"
}

# CIFAR needs its own resizing scheme and normalization statistics.
cifar10_args=(--train-resizing cifar --val-resizing cifar
              --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616)
cifar100_args=(--train-resizing cifar --val-resizing cifar
               --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761)
# MoCo-v2 runs start from an unsupervised backbone and use the cosine schedule.
moco_args=(--pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth --lr-scheduler cos)

# ImageNet Supervised Pretrain (ResNet50), continued
run_erm erm CIFAR10          data/cifar10    0.03  cifar10  "${cifar10_args[@]}"
run_erm erm CIFAR100         data/cifar100   0.01  cifar100 "${cifar100_args[@]}"
run_erm erm CUB200           data/cub200     0.003 cub200
run_erm erm Aircraft         data/aircraft   0.03  aircraft
run_erm erm StanfordCars     data/cars       0.03  car
run_erm erm SUN397           data/sun397     0.001 sun
run_erm erm DTD              data/dtd        0.03  dtd
run_erm erm OxfordIIITPets   data/pets       0.001 pets
run_erm erm OxfordFlowers102 data/flowers    0.03  flowers
run_erm erm Caltech101       data/caltech101 0.003 caltech

# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
run_erm erm_moco_pretrain Food101          data/food101    0.01  food101  "${moco_args[@]}"
run_erm erm_moco_pretrain CIFAR10          data/cifar10    0.001 cifar10  "${cifar10_args[@]}" "${moco_args[@]}"
run_erm erm_moco_pretrain CIFAR100         data/cifar100   0.001 cifar100 "${cifar100_args[@]}" "${moco_args[@]}"
run_erm erm_moco_pretrain CUB200           data/cub200     0.01  cub200   "${moco_args[@]}"
run_erm erm_moco_pretrain Aircraft         data/aircraft   0.01  aircraft "${moco_args[@]}"
run_erm erm_moco_pretrain StanfordCars     data/cars       0.03  car      "${moco_args[@]}"
run_erm erm_moco_pretrain SUN397           data/sun397     0.001 sun      "${moco_args[@]}"
run_erm erm_moco_pretrain DTD              data/dtd        0.001 dtd      "${moco_args[@]}"
run_erm erm_moco_pretrain OxfordIIITPets   data/pets       0.003 pets     "${moco_args[@]}"
run_erm erm_moco_pretrain OxfordFlowers102 data/flowers    0.01  flowers  "${moco_args[@]}"
run_erm erm_moco_pretrain Caltech101       data/caltech101 0.003 caltech  "${moco_args[@]}"
================================================ FILE: examples/semi_supervised_learning/image_classification/fixmatch.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import utils from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss from tllib.vision.transforms import MultipleApply from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, norm_mean=args.norm_mean, norm_std=args.norm_std) strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, auto_augment=args.auto_augment, norm_mean=args.norm_mean, norm_std=args.norm_std) labeled_train_transform = MultipleApply([weak_augment, strong_augment]) unlabeled_train_transform = MultipleApply([weak_augment, strong_augment]) val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std) print('labeled_train_transform: ', labeled_train_transform) print('unlabeled_train_transform: ', unlabeled_train_transform) print('val_transform:', val_transform) labeled_train_dataset, unlabeled_train_dataset, val_dataset = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed) print("labeled_dataset_size: ", len(labeled_train_dataset)) print('unlabeled_dataset_size: ', len(unlabeled_train_dataset)) print("val_dataset_size: ", len(val_dataset)) labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) labeled_train_iter = ForeverDataIterator(labeled_train_loader) unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone) num_classes = labeled_train_dataset.num_classes 
pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=args.finetune).to(device) print(classifier) # define optimizer and lr scheduler if args.lr_scheduler == 'exp': optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) else: optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) print(acc1) return # start training best_acc1 = 0.0 best_avg = 0.0 for epoch in range(args.epochs): # print lr print(lr_scheduler.get_lr()) # train for one epoch train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) best_avg = max(avg, best_avg) print("best_acc1 = {:3.1f}".format(best_acc1)) print('best_avg = {:3.1f}'.format(best_avg)) logger.close() def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':2.2f') data_time = 
AverageMeter('Data', ':2.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') self_training_losses = AverageMeter('Self Training Loss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f') pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs, pseudo_label_ratios], prefix="Epoch: [{}]".format(epoch)) self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device) # switch to train mode model.train() end = time.time() batch_size = args.batch_size for i in range(args.iters_per_epoch): (x_l, x_l_strong), labels_l = next(labeled_train_iter) x_l = x_l.to(device) x_l_strong = x_l_strong.to(device) labels_l = labels_l.to(device) (x_u, x_u_strong), labels_u = next(unlabeled_train_iter) x_u = x_u.to(device) x_u_strong = x_u_strong.to(device) labels_u = labels_u.to(device) # measure data loading time data_time.update(time.time() - end) # clear grad optimizer.zero_grad() # compute output # cross entropy loss y_l = model(x_l) y_l_strong = model(x_l_strong) cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l) cls_loss.backward() # self training loss with torch.no_grad(): y_u = model(x_u) y_u_strong = model(x_u_strong) self_training_loss, mask, pseudo_labels = self_training_criterion(y_u_strong, y_u) self_training_loss = args.trade_off_self_training * self_training_loss self_training_loss.backward() # measure accuracy and record loss loss = cls_loss + self_training_loss losses.update(loss.item(), batch_size) cls_losses.update(cls_loss.item(), batch_size) self_training_losses.update(self_training_loss.item(), batch_size) cls_acc = accuracy(y_l, labels_l)[0] cls_accs.update(cls_acc.item(), batch_size) # ratio of pseudo labels n_pseudo_labels = mask.sum() 
ratio = n_pseudo_labels / batch_size pseudo_label_ratios.update(ratio.item() * 100, batch_size) # accuracy of pseudo labels if n_pseudo_labels > 0: pseudo_labels = pseudo_labels * mask - (1 - mask) n_correct = (pseudo_labels == labels_u).float().sum() pseudo_label_acc = n_correct / n_pseudo_labels * 100 pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels) # compute gradient and do SGD step optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='FixMatch for Semi Supervised Learning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(utils.get_dataset_names())) parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class') parser.add_argument('--train-resizing', default='default', type=str) parser.add_argument('--val-resizing', default='default', type=str) parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+', help='normalization mean') parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+', help='normalization std') parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str, help='AutoAugment policy (default: rand-m10-n2-mstd2)') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor') parser.add_argument('--pretrained-backbone', default=None, 
type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') # training parameters parser.add_argument('--trade-off-cls-strong', default=0.1, type=float, help='the trade-off hyper-parameter of cls loss on strong augmented labeled data') parser.add_argument('--trade-off-self-training', default=1, type=float, help='the trade-off hyper-parameter of self training loss') parser.add_argument('--threshold', default=0.95, type=float, help='confidence threshold') parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr', help='initial learning rate') parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy') parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run (default: 60)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='fixmatch', type=str, help="where to save logs, checkpoints 
and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)

================================================ FILE: examples/semi_supervised_learning/image_classification/fixmatch.sh ================================================
#!/usr/bin/env bash
# Launch scripts for FixMatch (fixmatch.py), one command per benchmark dataset.
# NOTE(review): learning rate and confidence threshold differ per dataset — presumably tuned per benchmark;
# confirm against the repo README before reusing on a new dataset.

# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/fixmatch/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/fixmatch/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --threshold 0.9 --seed 0 --log logs/fixmatch/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/fixmatch/caltech_4_labels_per_class

# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# The runs below initialize the backbone from a MoCo v2 checkpoint (--pretrained-backbone) instead of the
# default ImageNet supervised weights, and use the cosine lr schedule (--lr-scheduler cos).
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/food101_4_labels_per_class
# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/cifar10_4_labels_per_class
# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/cifar100_4_labels_per_class
# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/cub200_4_labels_per_class
# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/aircraft_4_labels_per_class
# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/car_4_labels_per_class
# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --threshold 0.8 --seed 0 --log logs/fixmatch_moco_pretrain/sun_4_labels_per_class
# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/dtd_4_labels_per_class
# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/fixmatch_moco_pretrain/pets_4_labels_per_class
# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/fixmatch_moco_pretrain/flowers_4_labels_per_class
# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python fixmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/fixmatch_moco_pretrain/caltech_4_labels_per_class

================================================ FILE: examples/semi_supervised_learning/image_classification/flexmatch.py ================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import utils
from tllib.self_training.flexmatch
import DynamicThresholdingModule
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Build datasets, model and optimizer from ``args``, then run FlexMatch training
    and keep the checkpoint with the best validation acc@1. When ``args.phase == 'test'``,
    only evaluate the best checkpoint."""
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # weak augmentation = random horizontal flip; strong augmentation additionally applies auto-augment.
    # Both labeled and unlabeled samples yield a (weak, strong) pair via MultipleApply.
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform,
                          val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    # wrap the unlabeled dataset so each batch also yields sample indices (idx_u in train(),
    # consumed by the dynamic thresholding module to track per-sample learning status)
    unlabeled_train_dataset = utils.convert_dataset(unlabeled_train_dataset)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    args.num_classes = num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # thresholding module with convex mapping function x / (2 - x)
    thresholding_module = DynamicThresholdingModule(args.threshold, args.warmup, lambda x: x / (2 - x),
                                                    num_classes, len(unlabeled_train_dataset), device=device)

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, thresholding_module, classifier, optimizer, lr_scheduler,
              epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))

    logger.close()


def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator,
          thresholding_module: DynamicThresholdingModule, model, optimizer: SGD,
          lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Train ``model`` for one epoch of ``args.iters_per_epoch`` iterations.

    Each iteration combines a supervised cross-entropy loss on labeled data (weak + strong views)
    with a FlexMatch self-training loss: pseudo labels are predicted from the weakly-augmented
    unlabeled view (no grad) and applied to the strongly-augmented view, masked by a per-class
    dynamic confidence threshold from ``thresholding_module``.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        # idx_u: dataset indices of the unlabeled samples (from utils.convert_dataset);
        # labels_u is used only to report pseudo-label accuracy, never for training.
        idx_u, ((x_u, x_u_strong), labels_u) = next(unlabeled_train_iter)
        idx_u = idx_u.to(device)
        x_u = x_u.to(device)
        x_u_strong = x_u_strong.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # gradients are accumulated per loss term (backward is called again below for the
        # self-training loss); optimizer.step() then applies the combined update
        cls_loss.backward()

        # self training loss
        # pseudo labels come from the weak view under no_grad; only the strong view's forward
        # pass tracks gradients for the self-training loss
        with torch.no_grad():
            y_u = model(x_u)
        y_u_strong = model(x_u_strong)
        confidence, pseudo_labels = torch.softmax(y_u, dim=1).max(dim=1)
        # per-sample threshold depends on the predicted class (FlexMatch dynamic thresholding)
        dynamic_threshold = thresholding_module.get_threshold(pseudo_labels)
        mask = (confidence > dynamic_threshold).float()
        # mask used for updating learning status (fixed global threshold, not the dynamic one)
        selected_mask = (confidence > args.threshold).long()
        thresholding_module.update(idx_u, selected_mask, pseudo_labels)
        self_training_loss = args.trade_off_self_training * (
                F.cross_entropy(y_u_strong, pseudo_labels, reduction='none') * mask).mean()
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # ratio of pseudo labels
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / batch_size
        pseudo_label_ratios.update(ratio.item() * 100, batch_size)

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
            # rejected samples are mapped to label -1 so they never match labels_u
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FlexMatch for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False, help='no pool
layer after the feature extractor') parser.add_argument('--pretrained-backbone', default=None, type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') # training parameters parser.add_argument('--warmup', default=False, type=bool) parser.add_argument('--trade-off-cls-strong', default=0.1, type=float, help='the trade-off hyper-parameter of cls loss on strong augmented labeled data') parser.add_argument('--trade-off-self-training', default=1, type=float, help='the trade-off hyper-parameter of self training loss') parser.add_argument('--threshold', default=0.95, type=float, help='confidence threshold') parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr', help='initial learning rate') parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy') parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=90, type=int, metavar='N', help='number of total epochs to run (default: 90)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', 
default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='flexmatch', type=str, help="where to save logs, checkpoints and debugging images") parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'], help="when phase is 'test', only test the model") args = parser.parse_args() main(args) ================================================ FILE: examples/semi_supervised_learning/image_classification/flexmatch.sh ================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python 
flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/flexmatch/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.9 --seed 0 --log logs/flexmatch/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python 
flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.95 --seed 0 --log logs/flexmatch/caltech_4_labels_per_class # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune 
--lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/flexmatch_moco_pretrain/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/dtd 
-d DTD --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.9 --seed 0 --log logs/flexmatch_moco_pretrain/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python flexmatch.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/flexmatch_moco_pretrain/caltech_4_labels_per_class ================================================ FILE: examples/semi_supervised_learning/image_classification/mean_teacher.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.optim import SGD 
from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import utils from tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss from tllib.self_training.mean_teacher import update_bn, EMATeacher from tllib.vision.transforms import MultipleApply from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, norm_mean=args.norm_mean, norm_std=args.norm_std) strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, auto_augment=args.auto_augment, norm_mean=args.norm_mean, norm_std=args.norm_std) labeled_train_transform = MultipleApply([weak_augment, strong_augment]) unlabeled_train_transform = MultipleApply([weak_augment, weak_augment]) val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std) print('labeled_train_transform: ', labeled_train_transform) print('unlabeled_train_transform: ', unlabeled_train_transform) print('val_transform:', val_transform) labeled_train_dataset, unlabeled_train_dataset, val_dataset = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed) 
print("labeled_dataset_size: ", len(labeled_train_dataset)) print('unlabeled_dataset_size: ', len(unlabeled_train_dataset)) print("val_dataset_size: ", len(val_dataset)) labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) labeled_train_iter = ForeverDataIterator(labeled_train_loader) unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone) num_classes = labeled_train_dataset.num_classes pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=args.finetune).to(device) teacher = EMATeacher(classifier, alpha=args.alpha) print(classifier) # define optimizer and lr scheduler if args.lr_scheduler == 'exp': optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) else: optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) print(acc1) return # start training best_acc1 = 0.0 best_avg = 0.0 for epoch in range(args.epochs): # print lr print(lr_scheduler.get_lr()) # train for one epoch train(labeled_train_iter, unlabeled_train_iter, classifier, teacher, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) best_avg = max(avg, best_avg) print("best_acc1 = {:3.1f}".format(best_acc1)) print('best_avg = {:3.1f}'.format(best_avg)) logger.close() def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':2.2f') data_time = AverageMeter('Data', ':2.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') con_losses = AverageMeter('Con Loss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_losses, con_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) consistency_criterion = L2ConsistencyLoss(reduction='sum').to(device) # switch to train mode 
model.train() teacher.train() end = time.time() batch_size = args.batch_size for i in range(args.iters_per_epoch): (x_l, x_l_strong), labels_l = next(labeled_train_iter) x_l = x_l.to(device) x_l_strong = x_l_strong.to(device) labels_l = labels_l.to(device) (x_u, x_u_teacher), _ = next(unlabeled_train_iter) x_u = x_u.to(device) x_u_teacher = x_u_teacher.to(device) # measure data loading time data_time.update(time.time() - end) # clear grad optimizer.zero_grad() # compute output # cross entropy loss y_l = model(x_l) y_l_strong = model(x_l_strong) cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l) cls_loss.backward() # consistency loss with torch.no_grad(): y_u_teacher = teacher(x_u_teacher) p_u_teacher = torch.softmax(y_u_teacher, dim=1) y_u = model(x_u) p_u = torch.softmax(y_u, dim=1) con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) * \ consistency_criterion(p_u, p_u_teacher) con_loss.backward() # measure accuracy and record loss loss = cls_loss + con_loss losses.update(loss.item(), batch_size) cls_losses.update(cls_loss.item(), batch_size) con_losses.update(con_loss.item(), batch_size) cls_acc = accuracy(y_l, labels_l)[0] cls_accs.update(cls_acc.item(), batch_size) # compute gradient and do SGD step optimizer.step() lr_scheduler.step() # update teacher global_step = epoch * args.iters_per_epoch + i + 1 teacher.set_alpha(min(args.alpha, 1 - 1 / global_step)) teacher.update() update_bn(model, teacher.teacher) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Mean Teacher for Semi Supervised Learning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(utils.get_dataset_names())) 
parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class') parser.add_argument('--train-resizing', default='default', type=str) parser.add_argument('--val-resizing', default='default', type=str) parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+', help='normalization mean') parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+', help='normalization std') parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str, help='AutoAugment policy (default: rand-m10-n2-mstd2)') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor') parser.add_argument('--pretrained-backbone', default=None, type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') parser.add_argument('--alpha', default=0.999, type=float, help='ema decay factor') # training parameters parser.add_argument('--trade-off-cls-strong', default=0.1, type=float, help='the trade-off hyper-parameter of cls loss on strong augmented labeled data') parser.add_argument('--trade-off-con', default=0.1, type=float, help='the trade-off hyper-parameter of consistency loss') parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', dest='lr', help='initial learning rate') 
parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy') parser.add_argument('--lr-gamma', default=0.0001, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run (default: 40)') parser.add_argument('--warm-up-epochs', default=10, type=int, help='number of epochs to warm up (default: 10)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='mean_teacher', type=str, help="where to save logs, checkpoints and debugging images") parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'], help="when phase is 'test', only test the model") args = parser.parse_args() main(args) ================================================ FILE: examples/semi_supervised_learning/image_classification/mean_teacher.sh ================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/food101_4_labels_per_class # 
====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --lr 0.05 --finetune --seed 0 --log logs/mean_teacher/car_4_labels_per_class # ====================================================================================================================== # SUN397 
CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/mean_teacher/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --seed 0 --log logs/mean_teacher/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/mean_teacher/caltech_4_labels_per_class # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/food101_4_labels_per_class # 
====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/aircraft_4_labels_per_class # 
====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune 
--lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python mean_teacher.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/mean_teacher_moco_pretrain/caltech_4_labels_per_class ================================================ FILE: examples/semi_supervised_learning/image_classification/noisy_student.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import copy import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import utils from tllib.vision.models.reid.loss import CrossEntropyLoss from tllib.modules.classifier import Classifier from tllib.vision.transforms import MultipleApply from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class ImageClassifier(Classifier): def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs): bottleneck = nn.Sequential( nn.Linear(backbone.out_features, bottleneck_dim), nn.BatchNorm1d(bottleneck_dim), nn.ReLU() ) bottleneck[0].weight.data.normal_(0, 0.005) bottleneck[0].bias.data.fill_(0.1) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) self.dropout = nn.Dropout(0.5) self.as_teacher_model = False def forward(self, 
x: torch.Tensor): """""" f = self.pool_layer(self.backbone(x)) f = self.bottleneck(f) if not self.as_teacher_model: f = self.dropout(f) predictions = self.head(f) return predictions def calc_teacher_output(classifier_teacher: ImageClassifier, weak_augmented_unlabeled_dataset): """Compute outputs of the teacher network. Here, we use weak data augmentation and do not introduce an additional dropout layer according to the Noisy Student paper `Self-Training With Noisy Student Improves ImageNet Classification `_. """ data_loader = DataLoader(weak_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False) batch_time = AverageMeter('Time', ':6.3f') progress = ProgressMeter( len(data_loader), [batch_time], prefix='Computing teacher output: ') teacher_output = [] with torch.no_grad(): end = time.time() for i, (images, _) in enumerate(data_loader): images = images.to(device) output = classifier_teacher(images) teacher_output.append(output) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) teacher_output = torch.cat(teacher_output, dim=0) return teacher_output def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, norm_mean=args.norm_mean, norm_std=args.norm_std) strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, auto_augment=args.auto_augment, norm_mean=args.norm_mean, norm_std=args.norm_std) labeled_train_transform = MultipleApply([weak_augment, strong_augment]) val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std) print('labeled_train_transform: ', labeled_train_transform) print('weak_augment (input transform for teacher model): ', weak_augment) print('strong_augment (input transform for student model): ', strong_augment) print('val_transform:', val_transform) labeled_train_dataset, weak_augmented_unlabeled_dataset, val_dataset = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform, unlabeled_train_transform=weak_augment, seed=args.seed) _, strong_augmented_unlabeled_dataset, _ = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform, unlabeled_train_transform=strong_augment, seed=args.seed) strong_augmented_unlabeled_dataset = utils.convert_dataset(strong_augmented_unlabeled_dataset) print("labeled_dataset_size: ", len(labeled_train_dataset)) print('unlabeled_dataset_size: ', len(weak_augmented_unlabeled_dataset)) print("val_dataset_size: ", len(val_dataset)) labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) unlabeled_train_loader = DataLoader(strong_augmented_unlabeled_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) labeled_train_iter = ForeverDataIterator(labeled_train_loader) unlabeled_train_iter = 
ForeverDataIterator(unlabeled_train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone) num_classes = labeled_train_dataset.num_classes pool_layer = nn.Identity() if args.no_pool else None classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=args.finetune).to(device) print(classifier) if args.pretrained_teacher: # load teacher model classifier_teacher = copy.deepcopy(classifier) checkpoint = torch.load(args.pretrained_teacher) classifier_teacher.load_state_dict(checkpoint) classifier_teacher.eval() classifier_teacher.as_teacher_model = True print('compute outputs of the teacher network') teacher_output = calc_teacher_output(classifier_teacher, weak_augmented_unlabeled_dataset) # define optimizer and lr scheduler if args.lr_scheduler == 'exp': optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. 
+ args.lr_gamma * float(x)) ** (-args.lr_decay)) else: optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) print(acc1) return # start training best_acc1 = 0.0 best_avg = 0.0 for epoch in range(args.epochs): # print lr print(lr_scheduler.get_lr()) # train for one epoch if args.pretrained_teacher: train(labeled_train_iter, unlabeled_train_iter, classifier, teacher_output, optimizer, lr_scheduler, epoch, args) else: utils.empirical_risk_minimization(labeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args, device) # evaluate on validation set acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) best_avg = max(avg, best_avg) print("best_acc1 = {:3.1f}".format(best_acc1)) print('best_avg = {:3.1f}'.format(best_avg)) logger.close() def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, teacher_output, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':2.2f') data_time = AverageMeter('Data', ':2.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') self_training_losses = AverageMeter('Self Training Loss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, 
data_time, losses, cls_losses, self_training_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) self_training_criterion = CrossEntropyLoss().to(device) # switch to train mode model.train() end = time.time() batch_size = args.batch_size for i in range(args.iters_per_epoch): (x_l, x_l_strong), labels_l = next(labeled_train_iter) x_l = x_l.to(device) x_l_strong = x_l_strong.to(device) labels_l = labels_l.to(device) idx_u, (x_u_strong, _) = next(unlabeled_train_iter) idx_u = idx_u.to(device) x_u_strong = x_u_strong.to(device) # measure data loading time data_time.update(time.time() - end) # clear grad optimizer.zero_grad() # compute output y_l = model(x_l) y_l_strong = model(x_l_strong) # cross entropy loss cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l) cls_loss.backward() # self training loss y_u = teacher_output[idx_u] y_u_strong = model(x_u_strong) self_training_loss = args.trade_off_self_training * self_training_criterion(y_u_strong / args.T, y_u / args.T) self_training_loss.backward() # measure accuracy and record loss loss = cls_loss + self_training_loss losses.update(loss.item(), batch_size) cls_losses.update(cls_loss.item(), batch_size) self_training_losses.update(self_training_loss.item(), batch_size) cls_acc = accuracy(y_l, labels_l)[0] cls_accs.update(cls_acc.item(), batch_size) # compute gradient and do SGD step optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Noisy Student for Semi Supervised Learning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA', help='dataset: ' + ' | '.join(utils.get_dataset_names())) parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per 
class') parser.add_argument('--train-resizing', default='default', type=str) parser.add_argument('--val-resizing', default='default', type=str) parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+', help='normalization mean') parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+', help='normalization std') parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str, help='AutoAugment policy (default: rand-m10-n2-mstd2)') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck') parser.add_argument('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor') parser.add_argument('--pretrained-backbone', default=None, type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') parser.add_argument('--pretrained-teacher', default=None, type=str, help='pretrained checkpoint of the teacher model') # training parameters parser.add_argument('--trade-off-cls-strong', default=0.1, type=float, help='the trade-off hyper-parameter of cls loss on strong augmented labeled data') parser.add_argument('--trade-off-self-training', default=1, type=float, help='the trade-off hyper-parameter of self training loss') parser.add_argument('--T', default=2, type=float, help='temperature') parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr', help='initial learning rate') 
parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'], help='learning rate decay strategy') parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=40, type=int, metavar='N', help='number of total epochs to run (default: 40)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='noisy_student', type=str, help="where to save logs, checkpoints and debugging images") parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'], help="when phase is 'test', only test the model") args = parser.parse_args() main(args) ================================================ FILE: examples/semi_supervised_learning/image_classification/noisy_student.sh ================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --epochs 20 --seed 0 --log 
logs/noisy_student/cifar100_4_labels_per_class/iter_0 for round in 0 1 2; do CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-teacher logs/noisy_student/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \ --lr 0.01 --finetune --epochs 40 --T 0.5 --seed 0 --log logs/noisy_student/cifar100_4_labels_per_class/iter_$((round + 1)) done # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # CIFAR100 CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --epochs 20 --seed 0 \ --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_0 for round in 0 1 2; do CUDA_VISIBLE_DEVICES=0 python noisy_student.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --pretrained-teacher logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$round/checkpoints/latest.pth \ --lr 0.001 --finetune --lr-scheduler cos --epochs 40 --T 1 --seed 0 \ --log logs/noisy_student_moco_pretrain/cifar100_4_labels_per_class/iter_$((round + 1)) done ================================================ FILE: examples/semi_supervised_learning/image_classification/pi_model.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random 
import time
import warnings
import argparse
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import utils
from tllib.self_training.pi_model import sigmoid_warm_up, L2ConsistencyLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Entry point for Pi-Model semi-supervised training.

    Builds the labeled/unlabeled/validation data pipelines, the classifier,
    the optimizer and lr schedule, then either runs the training loop
    (``args.phase == 'train'``) or evaluates the best checkpoint
    (``args.phase == 'test'``).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # Labeled samples get (weak, strong) views; unlabeled samples get two
    # independently-sampled weak views for the Pi-Model consistency loss.
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = MultipleApply([weak_augment, weak_augment])
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform,
                          val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    # drop_last=True keeps every mini-batch at full size so the loss scales
    # used in train() stay comparable across iterations.
    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        # NOTE(review): LRScheduler.get_lr() is deprecated for external use in
        # newer PyTorch releases; get_last_lr() is the documented accessor —
        # confirm the torch version this repo targets before changing.
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()


def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of Pi-Model training.

    Per iteration: a supervised cross-entropy loss on (weak, strong) views of
    labeled data, plus an L2 consistency loss between predictions on two weak
    views of unlabeled data, ramped up by a sigmoid warm-up schedule.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    con_losses = AverageMeter('Con Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, con_losses, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    consistency_criterion = L2ConsistencyLoss().to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        (x_u1, x_u2), _ = next(unlabeled_train_iter)
        x_u1 = x_u1.to(device)
        x_u2 = x_u2.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # backward() is called per-loss; gradients from cls_loss and con_loss
        # accumulate and a single optimizer.step() below applies both.
        cls_loss.backward()

        # consistency loss
        y_u1 = model(x_u1)
        y_u2 = model(x_u2)
        p_u1 = torch.softmax(y_u1, dim=1)
        p_u2 = torch.softmax(y_u2, dim=1)
        # sigmoid_warm_up ramps the consistency weight from ~0 to 1 over
        # args.warm_up_epochs so early (unreliable) predictions contribute little.
        con_loss = args.trade_off_con * sigmoid_warm_up(epoch, args.warm_up_epochs) * consistency_criterion(p_u1, p_u2)
        con_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + con_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        con_losses.update(con_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pi Model for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-con', default=0.1, type=float,
                        help='the trade-off hyper-parameter of consistency loss')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run (default: 40)')
    parser.add_argument('--warm-up-epochs', default=10, type=int,
                        help='number of epochs to warm up (default: 10)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ')
    parser.add_argument("--log", default='pi_model', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/semi_supervised_learning/image_classification/pi_model.sh
================================================
#!/usr/bin/env bash
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/pi_model/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --seed 0 --log logs/pi_model/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --seed 0 --log logs/pi_model/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.03 --finetune --seed 0 --log logs/pi_model/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.01 --finetune --seed 0 --log logs/pi_model/caltech_4_labels_per_class

# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.01 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python pi_model.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.03 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python pi_model.py 
data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.003 --finetune --lr-scheduler cos --seed 0 --log logs/pi_model_moco_pretrain/caltech_4_labels_per_class


================================================
FILE: examples/semi_supervised_learning/image_classification/pseudo_label.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

import utils
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.data import ForeverDataIterator
from tllib.utils.logger import CompleteLogger

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args: argparse.Namespace):
    """Entry point for Pseudo-Label semi-supervised training.

    Builds the labeled/unlabeled/validation data pipelines, the classifier,
    the optimizer and lr schedule, then either runs the training loop
    (``args.phase == 'train'``) or evaluates the best checkpoint
    (``args.phase == 'test'``).
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    # Labeled samples get (weak, strong) views; unlabeled samples use a single
    # weak view, from which pseudo labels are derived in train().
    weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True,
                                               auto_augment=args.auto_augment,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std)
    labeled_train_transform = MultipleApply([weak_augment, strong_augment])
    unlabeled_train_transform = weak_augment
    val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std)
    print('labeled_train_transform: ', labeled_train_transform)
    print('unlabeled_train_transform: ', unlabeled_train_transform)
    print('val_transform:', val_transform)

    labeled_train_dataset, unlabeled_train_dataset, val_dataset = \
        utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform,
                          val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed)
    print("labeled_dataset_size: ", len(labeled_train_dataset))
    print('unlabeled_dataset_size: ', len(unlabeled_train_dataset))
    print("val_dataset_size: ", len(val_dataset))

    labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                      num_workers=args.workers, drop_last=True)
    unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True,
                                        num_workers=args.workers, drop_last=True)
    labeled_train_iter = ForeverDataIterator(labeled_train_loader)
    unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone)
    num_classes = labeled_train_dataset.num_classes
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                       pool_layer=pool_layer, finetune=args.finetune).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    if args.lr_scheduler == 'exp':
        optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True)
        lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
    else:
        optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd,
                        nesterov=True)
        lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)
        print(acc1)
        return

    # start training
    best_acc1 = 0.0
    best_avg = 0.0
    for epoch in range(args.epochs):
        # print lr
        # NOTE(review): LRScheduler.get_lr() is deprecated for external use in
        # newer PyTorch releases; get_last_lr() is the documented accessor —
        # confirm the torch version this repo targets before changing.
        print(lr_scheduler.get_lr())

        # train for one epoch
        train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
        best_avg = max(avg, best_avg)

    print("best_acc1 = {:3.1f}".format(best_acc1))
    print('best_avg = {:3.1f}'.format(best_avg))
    logger.close()


def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model,
          optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    """Run one epoch of Pseudo-Label training.

    Per iteration: a supervised cross-entropy loss on (weak, strong) views of
    labeled data, plus a confidence-thresholded self-training loss on
    unlabeled data. Also tracks the accuracy of accepted pseudo labels against
    the (otherwise unused) ground-truth unlabeled labels, for diagnostics only.
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs],
        prefix="Epoch: [{}]".format(epoch))

    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        x_u, labels_u = next(unlabeled_train_iter)
        x_u = x_u.to(device)
        labels_u = labels_u.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output
        # cross entropy loss
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)
        cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)
        # backward() is called per-loss; gradients accumulate and a single
        # optimizer.step() below applies both.
        cls_loss.backward()

        # self training loss
        y_u = model(x_u)
        # NOTE(review): the same logits are passed as both the prediction and
        # the pseudo-label source (only one weak view exists here); presumably
        # the criterion derives thresholded, non-differentiable targets from
        # its second argument — confirm against ConfidenceBasedSelfTrainingLoss.
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_u, y_u)
        self_training_loss = args.trade_off_self_training * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), batch_size)
        cls_losses.update(cls_loss.item(), batch_size)
        self_training_losses.update(self_training_loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # accuracy of pseudo labels
        n_pseudo_labels = mask.sum()
        if n_pseudo_labels > 0:
            # rejected entries become -1 so they can never match a true label
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_u).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            # NOTE(review): n_pseudo_labels is a tensor here, not a Python int;
            # AverageMeter presumably tolerates it — verify, or pass .item().
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pseudo Label for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) +
                             ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-self-training', default=1, type=float,
                        help='the trade-off hyper-parameter of self training loss')
    parser.add_argument('--threshold', default=0.95, type=float,
                        help='confidence threshold (default: 0.95)')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=40, type=int, metavar='N',
                        help='number of total epochs to run (default: 40)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ')
    parser.add_argument("--log", default='pseudo_label', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/semi_supervised_learning/image_classification/pseudo_label.sh
================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.8 --seed 0 --log logs/pseudo_label/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log 
logs/pseudo_label/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --lr 0.03 --finetune --threshold 0.95 --seed 0 --log logs/pseudo_label/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --lr 
0.003 --finetune --threshold 0.7 --seed 0 --log logs/pseudo_label/caltech_4_labels_per_class # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 
0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets 
CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/pseudo_label_moco_pretrain/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.03 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python pseudo_label.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --lr-scheduler cos --threshold 0.95 --seed 0 --log logs/pseudo_label_moco_pretrain/caltech_4_labels_per_class ================================================ FILE: examples/semi_supervised_learning/image_classification/requirements.txt ================================================ timm ================================================ FILE: examples/semi_supervised_learning/image_classification/self_tuning.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import utils from tllib.self_training.self_tuning import Classifier, SelfTuning from tllib.vision.transforms import MultipleApply from 
tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, auto_augment=args.auto_augment, norm_mean=args.norm_mean, norm_std=args.norm_std) train_transform = MultipleApply([strong_augment, strong_augment]) val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std) print('train_transform: ', train_transform) print('val_transform:', val_transform) labeled_train_dataset, unlabeled_train_dataset, val_dataset = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, train_transform, val_transform, seed=args.seed) print("labeled_dataset_size: ", len(labeled_train_dataset)) print('unlabeled_dataset_size: ', len(unlabeled_train_dataset)) print("val_dataset_size: ", len(val_dataset)) labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) labeled_train_iter = ForeverDataIterator(labeled_train_loader) unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, 
shuffle=False, num_workers=args.workers) # create model print("=> using pre-trained model '{}'".format(args.arch)) num_classes = labeled_train_dataset.num_classes backbone_q = utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone) pool_layer = nn.Identity() if args.no_pool else None classifier_q = Classifier(backbone_q, num_classes, projection_dim=args.projection_dim, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=args.finetune).to(device) print(classifier_q) backbone_k = utils.get_model(args.arch) classifier_k = Classifier(backbone_k, num_classes, projection_dim=args.projection_dim, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer).to(device) selftuning = SelfTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier_q.get_parameters(args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.milestones, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier_q.load_state_dict(checkpoint) acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes) print(acc1) return # start training best_acc1 = 0.0 best_avg = 0.0 for epoch in range(args.epochs): # print lr print(lr_scheduler.get_lr()) # train for one epoch train(labeled_train_iter, unlabeled_train_iter, selftuning, optimizer, epoch, args) # update lr lr_scheduler.step() # evaluate on validation set acc1, avg = utils.validate(val_loader, classifier_q, args, device, num_classes) # remember best acc@1 and save checkpoint torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) best_avg = max(avg, best_avg) 
print("best_acc1 = {:3.1f}".format(best_acc1))  # NOTE(review): these three statements conclude main() defined above
print('best_avg = {:3.1f}'.format(best_avg))
logger.close()


def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator,
          selftuning: SelfTuning, optimizer: SGD, epoch: int, args: argparse.Namespace):
    """Run one epoch of Self-Tuning training.

    Each iteration draws one labeled batch (two augmented views ``l_q``/``l_k`` plus
    labels) and one unlabeled batch (views ``u_q``/``u_k``). The labeled branch
    optimizes cross entropy plus a PGC (contrastive) KL loss; the unlabeled branch
    optimizes the PGC KL loss using pseudo labels taken as the argmax of
    ``selftuning.encoder_q``'s predictions on the query view. Gradients from both
    branches accumulate before a single ``optimizer.step()``.

    Args:
        labeled_train_iter: infinite iterator over the labeled training set.
        unlabeled_train_iter: infinite iterator over the unlabeled training set.
        selftuning: SelfTuning module wrapping query/key encoders and the PGC queue.
        optimizer: SGD optimizer over the query classifier's parameters.
        epoch: current epoch index (for progress display only).
        args: parsed command-line arguments (iters_per_epoch, batch_size, print_freq, ...).
    """
    batch_time = AverageMeter('Time', ':2.2f')
    data_time = AverageMeter('Data', ':2.1f')
    cls_losses = AverageMeter('Cls Loss', ':3.2f')
    pgc_losses_labeled = AverageMeter('Pgc Loss (Labeled Data)', ':3.2f')
    pgc_losses_unlabeled = AverageMeter('Pgc Loss (Unlabeled Data)', ':3.2f')
    losses = AverageMeter('Loss', ':3.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, pgc_losses_labeled, pgc_losses_unlabeled, cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # define loss functions
    # NOTE(review): KLDivLoss expects log-probabilities as input; presumably
    # selftuning() returns pgc logits already in log space — confirm against
    # tllib.self_training.self_tuning.SelfTuning.forward
    criterion_kl = nn.KLDivLoss(reduction='batchmean').to(device)

    # switch to train mode
    selftuning.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        # two augmented views per sample: query (q) and key (k)
        (l_q, l_k), labels_l = next(labeled_train_iter)
        (u_q, u_k), _ = next(unlabeled_train_iter)
        l_q, l_k = l_q.to(device), l_k.to(device)
        u_q, u_k = u_q.to(device), u_k.to(device)
        labels_l = labels_l.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # compute output on labeled data
        pgc_logits_labeled, pgc_labels_labeled, y_l = selftuning(l_q, l_k, labels_l)
        # cross entropy loss
        cls_loss = F.cross_entropy(y_l, labels_l)
        # pgc loss on labeled samples
        pgc_loss_labeled = criterion_kl(pgc_logits_labeled, pgc_labels_labeled)
        # first backward: gradients of the labeled objective accumulate in .grad
        (cls_loss + pgc_loss_labeled).backward()

        # pgc loss on unlabeled samples, supervised by argmax pseudo labels
        _, y_pred = selftuning.encoder_q(u_q)
        _, pseudo_labels = torch.max(y_pred, dim=1)
        pgc_logits_unlabeled, pgc_labels_unlabeled, _ = selftuning(u_q, u_k, pseudo_labels)
        pgc_loss_unlabeled = criterion_kl(pgc_logits_unlabeled, pgc_labels_unlabeled)
        # second backward: unlabeled gradients add onto the labeled ones
        pgc_loss_unlabeled.backward()

        # compute gradient and do SGD step (applies both accumulated objectives)
        optimizer.step()

        # measure accuracy and record loss
        cls_losses.update(cls_loss.item(), batch_size)
        pgc_losses_labeled.update(pgc_loss_labeled.item(), batch_size)
        pgc_losses_unlabeled.update(pgc_loss_unlabeled.item(), batch_size)
        loss = cls_loss + pgc_loss_labeled + pgc_loss_unlabeled
        losses.update(loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Self Tuning for Semi Supervised Learning')
    # dataset parameters
    parser.add_argument('root', metavar='DIR', help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int, help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int, help='dimension of bottleneck')
    parser.add_argument('--projection-dim', default=1024, type=int, help='dimension of projection head')
    parser.add_argument('--no-pool', action='store_true', default=False, help='no pool layer after the feature extractor')
parser.add_argument('--pretrained-backbone', default=None, type=str, help="pretrained checkpoint of the backbone " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--finetune', action='store_true', default=False, help='whether to use 10x smaller lr for backbone') # training parameters parser.add_argument('--T', default=0.07, type=float, help="temperature (default: 0.07)") parser.add_argument('--K', default=32, type=int, help="queue size (default: 32)") parser.add_argument('--m', default=0.999, type=float, help="momentum coefficient (default: 0.999)") parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N', help='mini-batch size (default: 32)') parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr', help='initial learning rate') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--milestones', default=[12, 24, 36, 48], type=int, nargs='+', help='epochs to decay learning rate') parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W', help='weight decay (default:5e-4)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--epochs', default=60, type=int, metavar='N', help='number of total epochs to run (default: 60)') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='number of iterations per epoch (default: 500)') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ') parser.add_argument("--log", default='self_tuning', type=str, help="where to save logs, checkpoints and debugging images") parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'], help="when phase is 'test', only 
test the model") args = parser.parse_args() main(args) ================================================ FILE: examples/semi_supervised_learning/image_classification/self_tuning.sh ================================================ #!/usr/bin/env bash # ImageNet Supervised Pretrain (ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/cub200_4_labels_per_class # ====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft 
-d Aircraft --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/self_tuning/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/self_tuning/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning/pets_4_labels_per_class # ====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --lr 0.01 --finetune --seed 0 --log logs/self_tuning/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --lr 0.003 
--finetune --seed 0 --log logs/self_tuning/caltech_4_labels_per_class # ImageNet Unsupervised Pretrain (MoCov2, ResNet50) # ====================================================================================================================== # Food 101 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/food101_4_labels_per_class # ====================================================================================================================== # CIFAR 10 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar10_4_labels_per_class # ====================================================================================================================== # CIFAR 100 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \ --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cifar100_4_labels_per_class # ====================================================================================================================== # CUB 200 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/cub200_4_labels_per_class # 
====================================================================================================================== # Aircraft CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.003 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/aircraft_4_labels_per_class # ====================================================================================================================== # StanfordCars CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.01 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/car_4_labels_per_class # ====================================================================================================================== # SUN397 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/sun_4_labels_per_class # ====================================================================================================================== # DTD CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/dtd_4_labels_per_class # ====================================================================================================================== # Oxford Pets CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/pets_4_labels_per_class # 
====================================================================================================================== # Oxford Flowers CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/flowers_4_labels_per_class # ====================================================================================================================== # Caltech 101 CUDA_VISIBLE_DEVICES=0 python self_tuning.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \ --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \ --lr 0.001 --finetune --seed 0 --log logs/self_tuning_moco_pretrain/caltech_4_labels_per_class ================================================ FILE: examples/semi_supervised_learning/image_classification/uda.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.optim.lr_scheduler import LambdaLR from torch.utils.data import DataLoader import utils from tllib.self_training.uda import StrongWeakConsistencyLoss from tllib.vision.transforms import MultipleApply from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. 
' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code weak_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, norm_mean=args.norm_mean, norm_std=args.norm_std) strong_augment = utils.get_train_transform(args.train_resizing, random_horizontal_flip=True, auto_augment=args.auto_augment, norm_mean=args.norm_mean, norm_std=args.norm_std) labeled_train_transform = MultipleApply([weak_augment, strong_augment]) unlabeled_train_transform = MultipleApply([weak_augment, strong_augment]) val_transform = utils.get_val_transform(args.val_resizing, norm_mean=args.norm_mean, norm_std=args.norm_std) print('labeled_train_transform: ', labeled_train_transform) print('unlabeled_train_transform: ', unlabeled_train_transform) print('val_transform:', val_transform) labeled_train_dataset, unlabeled_train_dataset, val_dataset = \ utils.get_dataset(args.data, args.num_samples_per_class, args.root, labeled_train_transform, val_transform, unlabeled_train_transform=unlabeled_train_transform, seed=args.seed) print("labeled_dataset_size: ", len(labeled_train_dataset)) print('unlabeled_dataset_size: ', len(unlabeled_train_dataset)) print("val_dataset_size: ", len(val_dataset)) labeled_train_loader = DataLoader(labeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) unlabeled_train_loader = DataLoader(unlabeled_train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) labeled_train_iter = ForeverDataIterator(labeled_train_loader) unlabeled_train_iter = ForeverDataIterator(unlabeled_train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = 
utils.get_model(args.arch, pretrained_checkpoint=args.pretrained_backbone) num_classes = labeled_train_dataset.num_classes pool_layer = nn.Identity() if args.no_pool else None classifier = utils.ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim, pool_layer=pool_layer, finetune=args.finetune).to(device) print(classifier) # define optimizer and lr scheduler if args.lr_scheduler == 'exp': optimizer = SGD(classifier.get_parameters(), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay)) else: optimizer = SGD(classifier.get_parameters(base_lr=args.lr), args.lr, momentum=0.9, weight_decay=args.wd, nesterov=True) lr_scheduler = utils.get_cosine_scheduler_with_warmup(optimizer, args.epochs * args.iters_per_epoch) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) print(acc1) return # start training best_acc1 = 0.0 best_avg = 0.0 for epoch in range(args.epochs): # print lr print(lr_scheduler.get_lr()) # train for one epoch train(labeled_train_iter, unlabeled_train_iter, classifier, optimizer, lr_scheduler, epoch, args) # evaluate on validation set acc1, avg = utils.validate(val_loader, classifier, args, device, num_classes) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) best_avg = max(avg, best_avg) print("best_acc1 = {:3.1f}".format(best_acc1)) print('best_avg = {:3.1f}'.format(best_avg)) logger.close() def train(labeled_train_iter: ForeverDataIterator, unlabeled_train_iter: ForeverDataIterator, model, optimizer: SGD, lr_scheduler: 
LambdaLR, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':2.2f') data_time = AverageMeter('Data', ':2.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') con_losses = AverageMeter('Con Loss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_losses, con_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) consistency_criterion = StrongWeakConsistencyLoss(args.threshold, args.T).to(device) # switch to train mode model.train() end = time.time() batch_size = args.batch_size for i in range(args.iters_per_epoch): (x_l, x_l_strong), labels_l = next(labeled_train_iter) x_l = x_l.to(device) x_l_strong = x_l_strong.to(device) labels_l = labels_l.to(device) (x_u, x_u_strong), _ = next(unlabeled_train_iter) x_u = x_u.to(device) x_u_strong = x_u_strong.to(device) # measure data loading time data_time.update(time.time() - end) # clear grad optimizer.zero_grad() # compute output # cross entropy loss y_l = model(x_l) y_l_strong = model(x_l_strong) cls_loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l) cls_loss.backward() # consistency loss with torch.no_grad(): y_u = model(x_u) y_u_strong = model(x_u_strong) con_loss = args.trade_off_con * consistency_criterion(y_u_strong, y_u) con_loss.backward() # measure accuracy and record loss loss = cls_loss + con_loss losses.update(loss.item(), batch_size) cls_losses.update(cls_loss.item(), batch_size) con_losses.update(con_loss.item(), batch_size) cls_acc = accuracy(y_l, labels_l)[0] cls_accs.update(cls_acc.item(), batch_size) # compute gradient and do SGD step optimizer.step() lr_scheduler.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='UDA for Semi Supervised Learning') # 
dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA',
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()))
    parser.add_argument('--num-samples-per-class', default=4, type=int,
                        help='number of labeled samples per class')
    parser.add_argument('--train-resizing', default='default', type=str)
    parser.add_argument('--val-resizing', default='default', type=str)
    parser.add_argument('--norm-mean', default=(0.485, 0.456, 0.406), type=float, nargs='+',
                        help='normalization mean')
    parser.add_argument('--norm-std', default=(0.229, 0.224, 0.225), type=float, nargs='+',
                        help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(),
                        help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true', default=False,
                        help='no pool layer after the feature extractor')
    parser.add_argument('--pretrained-backbone', default=None, type=str,
                        help="pretrained checkpoint of the backbone "
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='whether to use 10x smaller lr for backbone')
    # training parameters
    parser.add_argument('--trade-off-cls-strong', default=0.1, type=float,
                        help='the trade-off hyper-parameter of cls loss on strong augmented labeled data')
    parser.add_argument('--trade-off-con', default=1, type=float,
                        help='the trade-off hyper-parameter of consistency loss')
    parser.add_argument('--threshold', default=0.7, type=float,
                        help='confidence threshold')
    parser.add_argument('--T', default=0.85, type=float,
                        help='temperature')
    parser.add_argument('-b', '--batch-size', default=32, type=int, metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float, metavar='LR', dest='lr',
                        help='initial learning rate')
    parser.add_argument('--lr-scheduler', default='exp', type=str, choices=['exp', 'cos'],
                        help='learning rate decay strategy')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--wd', '--weight-decay', default=5e-4, type=float, metavar='W',
                        help='weight decay (default:5e-4)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=60, type=int, metavar='N',
                        help='number of total epochs to run (default: 60)')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int,
                        help='number of iterations per epoch (default: 500)')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training ')
    parser.add_argument("--log", default='uda', type=str,
                        help="where to save logs, checkpoints and debugging images")
    parser.add_argument("--phase", default='train', type=str, choices=['train', 'test'],
                        help="when phase is 'test', only test the model")
    # parse CLI arguments and run the training / test entry point
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/semi_supervised_learning/image_classification/uda.sh
================================================
#!/usr/bin/env bash
# ImageNet Supervised Pretrain (ResNet50)
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101
--num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.95 --seed 0 --log logs/uda/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.8 --seed 0 --log logs/uda/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --lr 0.003 --finetune --threshold 0.7 --seed 0 --log logs/uda/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --lr 0.001 --finetune --threshold 0.7 --seed 0 --log logs/uda/caltech_4_labels_per_class


# ImageNet Unsupervised Pretrain (MoCov2, ResNet50)
# NOTE: the commands below expect a converted MoCo v2 backbone at
# checkpoints/moco_v2_800ep_backbone.pth
# ======================================================================================================================
# Food 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/food101 -d Food101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/food101_4_labels_per_class

# ======================================================================================================================
# CIFAR 10
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar10 -d CIFAR10 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.4912 0.4824 0.4467 --norm-std 0.2471 0.2435 0.2616 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar10_4_labels_per_class

# ======================================================================================================================
# CIFAR 100
CUDA_VISIBLE_DEVICES=0 python uda.py data/cifar100 -d CIFAR100 --train-resizing 'cifar' --val-resizing 'cifar' \
  --norm-mean 0.5071 0.4867 0.4408 --norm-std 0.2675 0.2565 0.2761 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cifar100_4_labels_per_class

# ======================================================================================================================
# CUB 200
CUDA_VISIBLE_DEVICES=0 python uda.py data/cub200 -d CUB200 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/cub200_4_labels_per_class

# ======================================================================================================================
# Aircraft
CUDA_VISIBLE_DEVICES=0 python uda.py data/aircraft -d Aircraft --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/aircraft_4_labels_per_class

# ======================================================================================================================
# StanfordCars
CUDA_VISIBLE_DEVICES=0 python uda.py data/cars -d StanfordCars --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/car_4_labels_per_class

# ======================================================================================================================
# SUN397
CUDA_VISIBLE_DEVICES=0 python uda.py data/sun397 -d SUN397 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/sun_4_labels_per_class

# ======================================================================================================================
# DTD
CUDA_VISIBLE_DEVICES=0 python uda.py data/dtd -d DTD --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/dtd_4_labels_per_class

# ======================================================================================================================
# Oxford Pets
CUDA_VISIBLE_DEVICES=0 python uda.py data/pets -d OxfordIIITPets --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/pets_4_labels_per_class

# ======================================================================================================================
# Oxford Flowers
CUDA_VISIBLE_DEVICES=0 python uda.py data/flowers -d OxfordFlowers102 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler
cos --threshold 0.8 --seed 0 --log logs/uda_moco_pretrain/flowers_4_labels_per_class

# ======================================================================================================================
# Caltech 101
CUDA_VISIBLE_DEVICES=0 python uda.py data/caltech101 -d Caltech101 --num-samples-per-class 4 -a resnet50 \
  --pretrained-backbone checkpoints/moco_v2_800ep_backbone.pth \
  --lr 0.001 --finetune --lr-scheduler cos --threshold 0.7 --seed 0 --log logs/uda_moco_pretrain/caltech_4_labels_per_class


================================================
FILE: examples/semi_supervised_learning/image_classification/utils.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import math
import sys
import time
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataset import Subset, ConcatDataset
import torchvision.transforms as T
import timm
from timm.data.auto_augment import auto_augment_transform, rand_augment_transform

sys.path.append('../../..')
from tllib.modules.classifier import Classifier
import tllib.vision.datasets as datasets
import tllib.vision.models as models
from tllib.vision.transforms import ResizeImage
from tllib.utils.metric import accuracy, ConfusionMatrix
from tllib.utils.meter import AverageMeter, ProgressMeter


def get_model_names():
    """Return the sorted names of all usable backbones: tllib models plus every timm model."""
    return sorted(
        name for name in models.__dict__
        if name.islower() and not name.startswith("__")
        and callable(models.__dict__[name])
    ) + timm.list_models()


def get_model(model_name, pretrained=True, pretrained_checkpoint=None):
    """Create a headless backbone by name.

    The name is looked up in ``tllib.vision.models`` first, then in ``timm``.
    For timm models the classification head is removed so the module outputs
    features, with ``backbone.out_features`` recording the feature dimension.
    If ``pretrained_checkpoint`` is given, its weights are loaded non-strictly.
    """
    if model_name in models.__dict__:
        # load models from common.vision.models
        backbone = models.__dict__[model_name](pretrained=pretrained)
    else:
        # load models from pytorch-image-models (timm)
        backbone = timm.create_model(model_name, pretrained=pretrained)
        try:
            backbone.out_features = backbone.get_classifier().in_features
            backbone.reset_classifier(0, '')
        # NOTE(review): bare except silently falls back for ANY failure;
        # consider narrowing to AttributeError
        except:
            backbone.out_features = backbone.head.in_features
            backbone.head = nn.Identity()
    if pretrained_checkpoint:
        print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint))
        pretrained_dict = torch.load(pretrained_checkpoint)
        backbone.load_state_dict(pretrained_dict, strict=False)
    return backbone


def get_dataset_names():
    """Return the sorted names of all dataset classes exported by ``tllib.vision.datasets``."""
    return sorted(
        name for name in datasets.__dict__
        if not name.startswith("__") and callable(datasets.__dict__[name])
    )


def get_dataset(dataset_name, num_samples_per_class, root, labeled_train_transform, val_transform,
                unlabeled_train_transform=None, seed=0):
    """Build (labeled_train, unlabeled_train, val) datasets for semi-supervised learning."""
    if unlabeled_train_transform is None:
        unlabeled_train_transform = labeled_train_transform
    if dataset_name == 'OxfordFlowers102':
        dataset = datasets.__dict__[dataset_name]
        base_dataset = dataset(root=root, split='train', transform=labeled_train_transform, download=True)
        # create labeled and unlabeled splits
        labeled_idxes, unlabeled_idxes = x_u_split(num_samples_per_class, base_dataset.num_classes,
                                                   base_dataset.targets, seed=seed)
        # labeled subset
        labeled_train_dataset = Subset(base_dataset, labeled_idxes)
        labeled_train_dataset.num_classes = base_dataset.num_classes
        # unlabeled subset: remaining train images plus the official validation split
        base_dataset = dataset(root=root, split='train', transform=unlabeled_train_transform, download=True)
        unlabeled_train_dataset = ConcatDataset([
            Subset(base_dataset, unlabeled_idxes),
            dataset(root=root, split='validation', download=True, transform=unlabeled_train_transform)
        ])
        val_dataset = dataset(root=root, split='test', download=True, transform=val_transform)
    else:
        dataset = datasets.__dict__[dataset_name]
        base_dataset = dataset(root=root, split='train', transform=labeled_train_transform, download=True)
        # create labeled and unlabeled splits
        labeled_idxes, unlabeled_idxes = x_u_split(num_samples_per_class, base_dataset.num_classes,
                                                   base_dataset.targets, seed=seed)
        # labeled subset
        labeled_train_dataset = Subset(base_dataset, labeled_idxes)
labeled_train_dataset.num_classes = base_dataset.num_classes # unlabeled subset base_dataset = dataset(root=root, split='train', transform=unlabeled_train_transform, download=True) unlabeled_train_dataset = Subset(base_dataset, unlabeled_idxes) val_dataset = dataset(root=root, split='test', download=True, transform=val_transform) return labeled_train_dataset, unlabeled_train_dataset, val_dataset def x_u_split(num_samples_per_class, num_classes, labels, seed): """ Construct labeled and unlabeled subsets, where the labeled subset is class balanced. Note that the resulting subsets are **deterministic** with the same random seed. """ labels = np.array(labels) assert num_samples_per_class * num_classes <= len(labels) random_state = np.random.RandomState(seed) # labeled subset labeled_idxes = [] for i in range(num_classes): ith_class_idxes = np.where(labels == i)[0] ith_class_idxes = random_state.choice(ith_class_idxes, num_samples_per_class, False) labeled_idxes.extend(ith_class_idxes) # unlabeled subset unlabeled_idxes = [i for i in range(len(labels)) if i not in labeled_idxes] return labeled_idxes, unlabeled_idxes def get_train_transform(resizing='default', random_horizontal_flip=True, auto_augment=None, norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)): if resizing == 'default': transform = T.RandomResizedCrop(224, scale=(0.2, 1.)) elif resizing == 'cifar': transform = T.Compose([ T.RandomCrop(size=32, padding=4, padding_mode='reflect'), ResizeImage(224) ]) else: raise NotImplementedError(resizing) transforms = [transform] if random_horizontal_flip: transforms.append(T.RandomHorizontalFlip()) if auto_augment: aa_params = dict( translate_const=int(224 * 0.45), img_mean=tuple([min(255, round(255 * x)) for x in norm_mean]), interpolation=Image.BILINEAR ) if auto_augment.startswith('rand'): transforms.append(rand_augment_transform(auto_augment, aa_params)) else: transforms.append(auto_augment_transform(auto_augment, aa_params)) transforms.extend([ 
T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ])
    return T.Compose(transforms)


def get_val_transform(resizing='default', norm_mean=(0.485, 0.456, 0.406), norm_std=(0.229, 0.224, 0.225)):
    """Evaluation-time transform: deterministic resize / center crop, then tensor + normalize."""
    if resizing == 'default':
        transform = T.Compose([
            ResizeImage(256),
            T.CenterCrop(224),
        ])
    elif resizing == 'cifar':
        transform = ResizeImage(224)
    else:
        raise NotImplementedError(resizing)
    return T.Compose([
        transform,
        T.ToTensor(),
        T.Normalize(mean=norm_mean, std=norm_std)
    ])


def convert_dataset(dataset):
    """
    Converts a dataset which returns (img, label) pairs into one that returns (index, img, label) triplets.
    """

    class DatasetWrapper:

        def __init__(self):
            self.dataset = dataset

        def __getitem__(self, index):
            # prepend the sample index to whatever the wrapped dataset returns
            return index, self.dataset[index]

        def __len__(self):
            return len(self.dataset)

    return DatasetWrapper()


class ImageClassifier(Classifier):
    """Classifier with a bottleneck (Linear -> BatchNorm1d -> ReLU -> Dropout) between backbone and head."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # small-variance init for the bottleneck projection
        bottleneck[0].weight.data.normal_(0, 0.005)
        bottleneck[0].bias.data.fill_(0.1)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        # backbone features -> pool -> bottleneck -> classification head
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions


def get_cosine_scheduler_with_warmup(optimizer, T_max, num_cycles=7. / 16., num_warmup_steps=0, last_epoch=-1):
    """
    Cosine learning rate scheduler from `FixMatch: Simplifying Semi-Supervised Learning with Consistency and
    Confidence (NIPS 2020) <https://arxiv.org/abs/2001.07685>`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        num_cycles (float): A scalar that controls the shape of cosine function. Default: 7/16.
        num_warmup_steps (int): Number of iterations to warm up. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
""" def _lr_lambda(current_step): if current_step < num_warmup_steps: _lr = float(current_step) / float(max(1, num_warmup_steps)) else: num_cos_steps = float(current_step - num_warmup_steps) num_cos_steps = num_cos_steps / float(max(1, T_max - num_warmup_steps)) _lr = max(0.0, math.cos(math.pi * num_cycles * num_cos_steps)) return _lr return LambdaLR(optimizer, _lr_lambda, last_epoch) def validate(val_loader, model, args, device, num_classes): batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') top5 = AverageMeter('Acc@5', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1, top5], prefix='Test: ') # switch to evaluate mode model.eval() confmat = ConfusionMatrix(num_classes) with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): images = images.to(device) target = target.to(device) # compute output output = model(images) loss = F.cross_entropy(output, target) # measure accuracy and record loss acc1, acc5 = accuracy(output, target, topk=(1, 5)) confmat.update(target, output.argmax(1)) losses.update(loss.item(), images.size(0)) top1.update(acc1.item(), images.size(0)) top5.update(acc5.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' .format(top1=top1, top5=top5)) acc_global, acc_per_class, iu = confmat.compute() mean_cls_acc = acc_per_class.mean().item() * 100 print(' * Mean Cls {:.3f}'.format(mean_cls_acc)) return top1.avg, mean_cls_acc def empirical_risk_minimization(labeled_train_iter, model, optimizer, lr_scheduler, epoch, args, device): batch_time = AverageMeter('Time', ':2.2f') data_time = AverageMeter('Data', ':2.1f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, 
cls_accs],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    batch_size = args.batch_size
    for i in range(args.iters_per_epoch):
        (x_l, x_l_strong), labels_l = next(labeled_train_iter)
        x_l = x_l.to(device)
        x_l_strong = x_l_strong.to(device)
        labels_l = labels_l.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        y_l = model(x_l)
        y_l_strong = model(x_l_strong)

        # cross entropy loss on both weak augmented and strong augmented samples
        # (the strong-augmentation term is weighted by --trade-off-cls-strong)
        loss = F.cross_entropy(y_l, labels_l) + args.trade_off_cls_strong * F.cross_entropy(y_l_strong, labels_l)

        # measure accuracy and record loss
        losses.update(loss.item(), batch_size)
        cls_acc = accuracy(y_l, labels_l)[0]
        cls_accs.update(cls_acc.item(), batch_size)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


================================================
FILE: examples/task_adaptation/image_classification/README.md
================================================
# Task Adaptation for Image Classification

## Installation

Example scripts support all models in [PyTorch-Image-Models](https://github.com/rwightman/pytorch-image-models).
You need to install timm to use PyTorch-Image-Models.
``` pip install timm ``` ## Dataset Following datasets can be downloaded automatically: - [CUB200](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html) - [StanfordCars](https://ai.stanford.edu/~jkrause/cars/car_dataset.html) - [Aircraft](https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/) - [StanfordDogs](http://vision.stanford.edu/aditya86/ImageNetDogs/) - [OxfordIIITPets](https://www.robots.ox.ac.uk/~vgg/data/pets/) - [OxfordFlowers102](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/) - [DTD](https://www.robots.ox.ac.uk/~vgg/data/dtd/index.html) - [PatchCamelyon](https://patchcamelyon.grand-challenge.org/) - [EuroSAT](https://github.com/phelber/eurosat) You need to prepare following datasets manually if you want to use them: - [Retinopathy](https://www.kaggle.com/c/diabetic-retinopathy-detection/data) - [Resisc45](http://www.escience.cn/people/JunweiHan/NWPU-RESISC45.html) and prepare them following [Documentation for Retinopathy](/common/vision/datasets/retinopathy.py) and [Resisc45](/common/vision/datasets/resisc45.py). 
## Supported Methods

Supported methods include:

- [Explicit inductive bias for transfer learning with convolutional networks (L2-SP, ICML 2018)](https://arxiv.org/abs/1802.01483)
- [Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf)
- [DEep Learning Transfer using Feature Map with Attention for convolutional networks (DELTA, ICLR 2019)](https://openreview.net/pdf?id=rkgbwsAcYm)
- [Co-Tuning for Transfer Learning (Co-Tuning, NIPS 2020)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/co-tuning-for-transfer-learning-nips20.pdf)
- [Stochastic Normalization (StochNorm, NIPS 2020)](https://papers.nips.cc/paper/2020/file/bc573864331a9e42e4511de6f678aa83-Paper.pdf)
- [Learning Without Forgetting (LWF, ECCV 2016)](https://arxiv.org/abs/1606.09282)
- [Bi-tuning of Pre-trained Representations (Bi-Tuning)](https://arxiv.org/abs/2011.06182?utm_source=feedburner&utm_medium=feed&utm_campaign=Feed%3A+arxiv%2FQSXk+%28ExcitingAds%21+cs+updates+on+arXiv.org%29)

## Experiment and Results

We follow the common practice in the community as described in
[Catastrophic Forgetting Meets Negative Transfer: Batch Spectral Shrinkage for Safe Transfer Learning (BSS, NIPS 2019)](https://proceedings.neurips.cc/paper/2019/file/c6bff625bdb0393992c9d4db0c6bbe45-Paper.pdf).
Training iterations and data augmentations are kept the same for different task-adaptation methods for a fair
comparison. Hyper-parameters of each method are selected by the performance on target validation data.

### Fine-tune the supervised pre-trained model

The shell files give the script to reproduce the supervised pretrained benchmarks with specified hyper-parameters.
For example, if you want to use vanilla fine-tune on CUB200, use the following script

```shell script
# Fine-tune ResNet50 on CUB200.
# Assume you have put the datasets under the path `data/cub200`, # or you are glad to download the datasets automatically from the Internet to this path CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100 ``` #### Vision Benchmark on ResNet-50 | | Food101 | CIFAR10 | CIFAR100 | SUN397 | Standford Cars | FGVC Aircraft | DTD | Oxford-IIIT Pets | Caltech-101 | Oxford 102 Flowers | average | |-----------------|---------|---------|----------|--------|----------------|----------------|-------|------------------|----------------|--------------------|---------| | Accuracy metric | top1 | top1 | top1 | top1 | top1 | mean per-class | top-1 | mean per-class | mean per-class | mean per-class | | | Baseline | 85.1 | 96.6 | 84.1 | 63.7 | 87.8 | 80.1 | 70.8 | 93.2 | 91.1 | 93.0 | 84.6 | | LWF | 83.9 | 96.5 | 83.6 | 64.1 | 87.4 | 82.2 | 72.2 | 94.0 | 89.8 | 92.9 | 84.7 | | DELTA | 83.8 | 95.9 | 83.7 | 64.5 | 88.1 | 82.3 | 72.2 | 94.2 | 90.1 | 93.1 | 84.8 | | BSS | 85.0 | 96.6 | 84.2 | 63.5 | 88.4 | 81.8 | 70.2 | 93.3 | 91.6 | 92.7 | 84.7 | | StochNorm | 85.0 | 96.8 | 83.9 | 63.0 | 87.7 | 81.5 | 71.3 | 93.6 | 90.5 | 92.9 | 84.6 | | Bi-Tuning | 85.7 | 97.1 | 84.3 | 64.2 | 90.3 | 84.8 | 70.6 | 93.5 | 91.5 | 94.5 | 85.7 | #### CUB-200-2011 on ResNet-50 (Supervised Pre-trained) | CUB200 | 15% | 30% | 50% | 100% | Avg | |-----------|------|------|------|------|------| | ERM | 51.2 | 64.6 | 74.6 | 81.8 | 68.1 | | lwf | 56.7 | 66.8 | 73.4 | 81.5 | 69.6 | | BSS | 53.4 | 66.7 | 76.0 | 82.0 | 69.5 | | delta | 54.8 | 67.3 | 76.3 | 82.3 | 70.2 | | StochNorm | 54.8 | 66.8 | 75.8 | 82.2 | 69.9 | | Co-tuning | 57.6 | 70.1 | 77.3 | 82.5 | 71.9 | | bi-tuning | 55.8 | 69.3 | 77.2 | 83.1 | 71.4 | #### Stanford Cars on ResNet-50 (Supervised Pre-trained) | Standford Cars | 15% | 30% | 50% | 100% | Avg | |----------------|------|------|------|------|------| | ERM | 41.1 | 65.9 | 78.4 | 87.8 | 68.3 | | lwf | 44.9 | 67.0 | 77.6 | 87.5 | 69.3 | | 
BSS | 43.3 | 67.6 | 79.6 | 88.0 | 69.6 | | delta | 45.0 | 68.4 | 79.6 | 88.4 | 70.4 | | StochNorm | 44.4 | 68.1 | 79.3 | 87.9 | 69.9 | | Co-tuning | 49.0 | 70.6 | 81.9 | 89.1 | 72.7 | | bi-tuning | 48.3 | 72.8 | 83.3 | 90.2 | 73.7 | #### FGVC Aircraft on ResNet-50 (Supervised Pre-trained) | FGVC Aircraft | 15% | 30% | 50% | 100% | Avg | |---------------|------|------|------|------|------| | ERM | 41.6 | 57.8 | 68.7 | 80.2 | 62.1 | | lwf | 44.1 | 60.6 | 68.7 | 82.4 | 64.0 | | BSS | 43.6 | 59.5 | 69.6 | 81.2 | 63.5 | | delta | 44.4 | 61.9 | 71.4 | 82.7 | 65.1 | | StochNorm | 44.3 | 60.6 | 70.1 | 81.5 | 64.1 | | Co-tuning | 45.9 | 61.2 | 71.3 | 82.2 | 65.2 | | bi-tuning | 47.2 | 64.3 | 73.7 | 84.3 | 67.4 | ### Fine-tune the unsupervised pre-trained model Take MoCo as an example. 1. Download MoCo pretrained checkpoints from https://github.com/facebookresearch/moco 2. Convert the format of the MoCo checkpoints to the standard format of pytorch ```shell mkdir checkpoints python convert_moco_to_pretrained.py checkpoints/moco_v1_200ep_pretrain.pth.tar checkpoints/moco_v1_200ep_backbone.pth checkpoints/moco_v1_200ep_fc.pth ``` 3. 
Start training ```shell CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth ``` #### CUB-200-2011 on ResNet-50 (MoCo Pre-trained) | CUB200 | 15% | 30% | 50% | 100% | Avg | |-----------|------|------|------|------|------| | ERM | 28.0 | 48.2 | 62.7 | 75.6 | 53.6 | | lwf | 28.8 | 50.1 | 62.8 | 76.2 | 54.5 | | BSS | 30.9 | 50.3 | 63.7 | 75.8 | 55.2 | | delta | 27.9 | 51.4 | 65.9 | 74.6 | 55.0 | | StochNorm | 20.8 | 44.9 | 60.1 | 72.8 | 49.7 | | Co-tuning | 29.1 | 50.1 | 63.8 | 75.9 | 54.7 | | bi-tuning | 32.4 | 51.8 | 65.7 | 76.1 | 56.5 | #### Stanford Cars on ResNet-50 (MoCo Pre-trained) | Standford Cars | 15% | 30% | 50% | 100% | Avg | |----------------|------|------|------|------|------| | ERM | 42.5 | 71.2 | 83.0 | 90.1 | 71.7 | | lwf | 44.2 | 71.7 | 82.9 | 90.5 | 72.3 | | BSS | 45.0 | 71.5 | 83.8 | 90.1 | 72.6 | | delta | 45.9 | 72.9 | 82.5 | 88.9 | 72.6 | | StochNorm | 40.3 | 66.2 | 78.0 | 86.2 | 67.7 | | Co-tuning | 44.2 | 72.6 | 83.3 | 90.3 | 72.6 | | bi-tuning | 45.6 | 72.8 | 83.2 | 90.8 | 73.1 | #### FGVC Aircraft on ResNet-50 (MoCo Pre-trained) | FGVC Aircraft | 15% | 30% | 50% | 100% | Avg | |---------------|------|------|------|------|------| | ERM | 45.8 | 67.6 | 78.8 | 88.0 | 70.1 | | lwf | 48.5 | 68.5 | 78.0 | 87.9 | 70.7 | | BSS | 47.7 | 69.1 | 79.2 | 88.0 | 71.0 | | delta | - | - | - | - | - | | StochNorm | 45.4 | 68.8 | 76.7 | 86.1 | 69.3 | | Co-tuning | 48.2 | 68.5 | 78.7 | 87.3 | 70.7 | | bi-tuning | 46.4 | 69.6 | 79.4 | 87.9 | 70.8 | ## Citation If you use these methods in your research, please consider citing. 
``` @inproceedings{LWF, author = {Zhizhong Li and Derek Hoiem}, title = {Learning without Forgetting}, booktitle={ECCV}, year = {2016}, } @inproceedings{L2SP, title={Explicit inductive bias for transfer learning with convolutional networks}, author={Xuhong, LI and Grandvalet, Yves and Davoine, Franck}, booktitle={ICML}, year={2018}, } @inproceedings{BSS, title={Catastrophic forgetting meets negative transfer: Batch spectral shrinkage for safe transfer learning}, author={Chen, Xinyang and Wang, Sinan and Fu, Bo and Long, Mingsheng and Wang, Jianmin}, booktitle={NeurIPS}, year={2019} } @inproceedings{DELTA, title={Delta: Deep learning transfer using feature map with attention for convolutional networks}, author={Li, Xingjian and Xiong, Haoyi and Wang, Hanchao and Rao, Yuxuan and Liu, Liping and Chen, Zeyu and Huan, Jun}, booktitle={ICLR}, year={2019} } @inproceedings{StocNorm, title={Stochastic Normalization}, author={Kou, Zhi and You, Kaichao and Long, Mingsheng and Wang, Jianmin}, booktitle={NeurIPS}, year={2020} } @inproceedings{CoTuning, title={Co-Tuning for Transfer Learning}, author={You, Kaichao and Kou, Zhi and Long, Mingsheng and Wang, Jianmin}, booktitle={NeurIPS}, year={2020} } @article{BiTuning, title={Bi-tuning of Pre-trained Representations}, author={Zhong, Jincheng and Wang, Ximei and Kou, Zhi and Wang, Jianmin and Long, Mingsheng}, journal={arXiv preprint arXiv:2011.06182}, year={2020} } ``` ================================================ FILE: examples/task_adaptation/image_classification/bi_tuning.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import utils from tllib.vision.transforms import MultipleApply from tllib.utils.metric import accuracy from 
tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger from tllib.regularization.bi_tuning import Classifier, BiTuning device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_augmentation = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) train_transform = MultipleApply([train_augmentation, train_augmentation]) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone_q = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier_q = Classifier(backbone_q, num_classes, pool_layer=pool_layer, projection_dim=args.projection_dim, finetune=args.finetune) if args.pretrained_fc: print("=> loading pre-trained fc from 
'{}'".format(args.pretrained_fc)) pretrained_fc_dict = torch.load(args.pretrained_fc) classifier_q.projector.load_state_dict(pretrained_fc_dict, strict=False) classifier_q = classifier_q.to(device) backbone_k = utils.get_model(args.arch) classifier_k = Classifier(backbone_k, num_classes, pool_layer=pool_layer).to(device) bituning = BiTuning(classifier_q, classifier_k, num_classes, K=args.K, m=args.m, T=args.T) # define optimizer and lr scheduler optimizer = SGD(classifier_q.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier_q.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier_q, args, device) print(acc1) return # start training best_acc1 = 0.0 for epoch in range(args.epochs): print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, bituning, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier_q, args, device) # remember best acc@1 and save checkpoint torch.save(classifier_q.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, bituning: BiTuning, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') cls_losses = AverageMeter('Cls Loss', ':3.2f') contrastive_losses = AverageMeter('Contrastive Loss', ':3.2f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, 
[batch_time, data_time, losses, cls_losses, contrastive_losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) classifier_criterion = torch.nn.CrossEntropyLoss().to(device) contrastive_criterion = torch.nn.KLDivLoss(reduction='batchmean').to(device) # switch to train mode bituning.train() end = time.time() for i in range(args.iters_per_epoch): x, labels = next(train_iter) img_q, img_k = x[0], x[1] img_q = img_q.to(device) img_k = img_k.to(device) labels = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, logits_z, logits_y, bituning_labels = bituning(img_q, img_k, labels) cls_loss = classifier_criterion(y, labels) contrastive_loss_z = contrastive_criterion(logits_z, bituning_labels) contrastive_loss_y = contrastive_criterion(logits_y, bituning_labels) contrastive_loss = (contrastive_loss_z + contrastive_loss_y) loss = cls_loss + contrastive_loss * args.trade_off # measure accuracy and record loss losses.update(loss.item(), x[0].size(0)) cls_losses.update(cls_loss.item(), x[0].size(0)) contrastive_losses.update(contrastive_loss.item(), x[0].size(0)) cls_acc = accuracy(y, labels)[0] cls_accs.update(cls_acc.item(), x[0].size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Bi-tuning for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode 
during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--pretrained-fc', default=None, help="pretrained checkpoint of the fc. " "(default: None)") parser.add_argument('--T', default=0.07, type=float, help="temperature. (default: 0.07)") parser.add_argument('--K', type=int, default=40, help="queue size. (default: 40)") parser.add_argument('--m', type=float, default=0.999, help="momentum coefficient. (default: 0.999)") parser.add_argument('--projection-dim', type=int, default=128, help="dimension of the projection head. (default: 128)") parser.add_argument('--trade-off', type=float, default=1.0, help="trade-off parameters. 
(default: 1.0)") # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam']) parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='bi_tuning', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/task_adaptation/image_classification/bi_tuning.sh ================================================ #!/usr/bin/env bash # Supervised Pretraining # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --finetune --seed 0 --log logs/bi_tuning/cub200_100 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 50 --finetune --seed 0 --log logs/bi_tuning/cub200_50 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --finetune --seed 0 --log logs/bi_tuning/cub200_30 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --finetune --seed 0 --log logs/bi_tuning/cub200_15 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --finetune --seed 0 --log logs/bi_tuning/car_100 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --finetune --seed 0 --log logs/bi_tuning/car_50 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --finetune --seed 0 --log logs/bi_tuning/car_30 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --finetune --seed 0 --log logs/bi_tuning/car_15 # Aircrafts CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/bi_tuning/aircraft_100 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bi_tuning/aircraft_50 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bi_tuning/aircraft_30 CUDA_VISIBLE_DEVICES=0 python 
bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bi_tuning/aircraft_15 # CIFAR10 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bi_tuning/cifar10/1e-2 --lr 1e-2 # CIFAR100 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bi_tuning/cifar100/1e-2 --lr 1e-2 # Flowers CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bi_tuning/oxford_flowers102/1e-2 --lr 1e-2 # Pets CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bi_tuning/oxford_pet/1e-2 --lr 1e-2 # DTD CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/bi_tuning/dtd/1e-2 --lr 1e-2 # caltech101 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bi_tuning/caltech101/lr_1e-3 --lr 1e-3 # SUN397 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bi_tuning/sun397/lr_1e-2 --lr 1e-2 # Food 101 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bi_tuning/food-101/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bi_tuning/stanford_cars/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bi_tuning/aircraft/lr_1e-2 --lr 1e-2 # MoCo (Unsupervised Pretraining) # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 
--lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Standford Cars CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Aircrafts CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 
50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bi_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bi_tuning/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth ================================================ FILE: examples/task_adaptation/image_classification/bss.py ================================================ """ @author: Yifei Ji, Junguang Jiang @contact: jiyf990330@163.com, JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.regularization.bss import BatchSpectralShrinkage from tllib.modules.classifier import Classifier from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device) bss_module = BatchSpectralShrinkage(k=args.k) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0.0 for epoch in range(args.epochs): print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, classifier, bss_module, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = 
utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model: Classifier, bss_module, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, labels = next(train_iter) x = x.to(device) label = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, f = model(x) cls_loss = F.cross_entropy(y, label) bss_loss = bss_module(f) loss = cls_loss + args.trade_off * bss_loss cls_acc = accuracy(y, label)[0] losses.update(loss.item(), x.size(0)) cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='BSS for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') 
parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. " "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('-k', '--k', default=1, type=int, metavar='N', help='hyper-parameter for BSS loss') parser.add_argument('--trade-off', default=0.001, type=float, metavar='P', help='trade-off weight of BSS loss') # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', 
help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='bss', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/task_adaptation/image_classification/bss.sh ================================================ #!/usr/bin/env bash # Supervised Pretraining # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/bss/cub200_100 CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/bss/cub200_50 CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/bss/cub200_30 CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/bss/cub200_15 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/bss/car_100 CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/bss/car_50 CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/bss/car_30 CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/bss/car_15 # Aircrafts CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d 
Aircraft -sr 100 --seed 0 --finetune --log logs/bss/aircraft_100 CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/bss/aircraft_50 CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/bss/aircraft_30 CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/bss/aircraft_15 # CIFAR10 CUDA_VISIBLE_DEVICES=0 python bss.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/bss/cifar10/1e-2 --lr 1e-2 # CIFAR100 CUDA_VISIBLE_DEVICES=0 python bss.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/bss/cifar100/1e-2 --lr 1e-2 # Flowers CUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/bss/oxford_flowers102/1e-2 --lr 1e-2 # Pets CUDA_VISIBLE_DEVICES=0 python bss.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/bss/oxford_pet/1e-2 --lr 1e-2 # DTD CUDA_VISIBLE_DEVICES=0 python bss.py data/dtd -d DTD --seed 0 --finetune --log logs/bss/dtd/1e-2 --lr 1e-2 # caltech101 CUDA_VISIBLE_DEVICES=0 python bss.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/bss/caltech101/lr_1e-3 --lr 1e-3 # SUN397 CUDA_VISIBLE_DEVICES=0 python bss.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/bss/sun397/lr_1e-2 --lr 1e-2 # Food 101 CUDA_VISIBLE_DEVICES=0 python bss.py data/food-101 -d Food101 --seed 0 --finetune --log logs/bss/food-101/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/bss/stanford_cars/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/bss/aircraft/lr_1e-2 --lr 1e-2 # MoCo (Unsupervised Pretraining) # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log 
logs/moco_pretrain_bss/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Standford Cars CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python bss.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Aircrafts CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_bss/aircraft_100 --pretrained 
checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python bss.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_bss/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth



================================================
FILE: examples/task_adaptation/image_classification/co_tuning.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil
import os

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Subset

import utils
from tllib.regularization.co_tuning import CoTuningLoss, Relationship, Classifier
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.data import ForeverDataIterator
import tllib.vision.datasets as datasets

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None):
    """Build the Co-Tuning datasets.

    Unlike the shared utils.get_dataset, this also returns a second copy of the
    training split ("determin_train_dataset") that uses the deterministic
    val_transform — it is later iterated in a fixed order to compute the
    source/target category relationship.

    Returns (train_dataset, determin_train_dataset, test_dataset, num_classes).
    """
    dataset = datasets.__dict__[dataset_name]
    if sample_rate < 100:
        # subsampled training split; the test split always uses the full data
        train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True, transform=train_transform)
        determin_train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True, transform=val_transform)
        test_dataset = dataset(root=root, split='test', sample_rate=100, download=True, transform=val_transform)
        num_classes = train_dataset.num_classes
    else:
        train_dataset = dataset(root=root, split='train', transform=train_transform)
        determin_train_dataset = dataset(root=root, split='train', transform=val_transform)
        test_dataset = dataset(root=root, split='test', transform=val_transform)
        num_classes = train_dataset.num_classes
    if num_samples_per_classes is not None:
        # take a random subset of at most num_samples_per_classes * num_classes
        # samples; the same index list is applied to both training copies so
        # they stay aligned. NOTE(review): the subset is drawn over the whole
        # split, so exact per-class balance is not guaranteed.
        samples = list(range(len(train_dataset)))
        random.shuffle(samples)
        samples_len = min(num_samples_per_classes * num_classes, len(train_dataset))
        train_dataset = Subset(train_dataset, samples[:samples_len])
        determin_train_dataset = Subset(determin_train_dataset, samples[:samples_len])
    return train_dataset, determin_train_dataset, test_dataset, num_classes


def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter)
    val_transform = utils.get_val_transform(args.val_resizing)
    print("train_transform: ", train_transform)
    print("val_transform: ", val_transform)

    train_dataset, determin_train_dataset, val_dataset, num_classes = get_dataset(args.data, args.root,
                                                                                 train_transform, val_transform,
                                                                                 args.sample_rate,
                                                                                 args.num_samples_per_classes)
    print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset)))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.workers, drop_last=True)
    # deterministic pass over the training data (no shuffle, keep last batch)
    # used below to estimate the category relationship
    determin_train_loader = DataLoader(determin_train_dataset, batch_size=args.batch_size, shuffle=False,
                                       num_workers=args.workers, drop_last=False)
    train_iter = ForeverDataIterator(train_loader)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    # create model: keeps the pretrained source head (copy_head) alongside the
    # new target head, as Co-Tuning supervises both
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, args.pretrained)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(),
                            pool_layer=pool_layer, finetune=args.finetune).to(device)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma)

    # resume from the best checkpoint
    if args.phase == 'test':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)
        acc1 = utils.validate(val_loader, classifier, args, device)
        print(acc1)
        return

    # build relationship between source classes and target classes
    source_classifier = 
nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source) relationship = Relationship(determin_train_loader, source_classifier, device, os.path.join(logger.root, args.relationship)) co_tuning_loss = CoTuningLoss() # start training best_acc1 = 0.0 for epoch in range(args.epochs): # train for one epoch train(train_iter, classifier, optimizer, epoch, relationship, co_tuning_loss, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD, epoch: int, relationship, co_tuning_loss, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, label_t = next(train_iter) x = x.to(device) label_s = torch.from_numpy(relationship[label_t]).cuda().float() label_t = label_t.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s, y_t = model(x) tgt_loss = F.cross_entropy(y_t, label_t) src_loss = co_tuning_loss(y_s, label_s) loss = tgt_loss + args.trade_off * src_loss # measure accuracy and record loss losses.update(loss.item(), x.size(0)) cls_acc = accuracy(y_t, label_t)[0] cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed 
time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Co-Tuning for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--trade-off', default=2.3, type=float, metavar='P', help='the trade-off hyper-parameter for co-tuning loss') parser.add_argument("--relationship", type=str, default='relationship.npy', help="Where to save relationship file.") parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='cotuning', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/task_adaptation/image_classification/co_tuning.sh ================================================ #!/usr/bin/env bash # Supervised Pretraining # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/co_tuning/cub200_100 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/co_tuning/cub200_50 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/co_tuning/cub200_30 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/co_tuning/cub200_15 # Standford Cars CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/co_tuning/car_100 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/co_tuning/car_50 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/co_tuning/car_30 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/co_tuning/car_15 # Aircrafts CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/co_tuning/aircraft_100 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/co_tuning/aircraft_50 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/co_tuning/aircraft_30 CUDA_VISIBLE_DEVICES=0 python 
co_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/co_tuning/aircraft_15 # CIFAR10 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/co_tuning/cifar10/1e-2 --lr 1e-2 # CIFAR100 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/co_tuning/cifar100/1e-2 --lr 1e-2 # Flowers CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/co_tuning/oxford_flowers102/1e-2 --lr 1e-2 # Pets CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/co_tuning/oxford_pet/1e-2 --lr 1e-2 # DTD CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/dtd -d DTD --seed 0 --finetune --log logs/co_tuning/dtd/1e-2 --lr 1e-2 # caltech101 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/co_tuning/caltech101/lr_1e-3 --lr 1e-3 # SUN397 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/co_tuning/sun397/lr_1e-2 --lr 1e-2 # Food 101 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/food-101 -d Food101 --seed 0 --finetune --log logs/co_tuning/food-101/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/co_tuning/stanford_cars/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/co_tuning/aircraft/lr_1e-2 --lr 1e-2 # MoCo (Unsupervised Pretraining) # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 
--lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Standford Cars CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Aircrafts CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 
50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python co_tuning.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_co_tuning/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth ================================================ FILE: examples/task_adaptation/image_classification/convert_moco_to_pretrained.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import sys import torch if __name__ == "__main__": input = sys.argv[1] obj = torch.load(input, map_location="cpu") obj = obj["state_dict"] newmodel = {} fc = {} for k, v in obj.items(): if not k.startswith("module.encoder_q."): continue old_k = k k = k.replace("module.encoder_q.", "") if k.startswith("fc"): print(k) fc[k] = v else: newmodel[k] = v with open(sys.argv[2], "wb") as f: torch.save(newmodel, f) with open(sys.argv[3], "wb") as f: torch.save(fc, f) ================================================ FILE: examples/task_adaptation/image_classification/delta.py ================================================ """ @author: Yifei Ji, Junguang Jiang @contact: jiyf990330@163.com, JiangJunguang1123@outlook.com """ import math import os import random import time import warnings import sys import argparse import shutil import numpy as np from tqdm import tqdm import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.regularization.delta import * from 
tllib.modules.classifier import Classifier from tllib.utils.data import ForeverDataIterator from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.pretrained) backbone_source = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device) source_classifier = Classifier(backbone_source, 
num_classes=backbone_source.fc.out_features, head=backbone_source.copy_head(), pool_layer=pool_layer).to(device) for param in source_classifier.parameters(): param.requires_grad = False source_classifier.eval() # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier, args, device) print(acc1) return # create intermediate layer getter if args.arch == 'resnet50': return_layers = ['backbone.layer1.2.conv3', 'backbone.layer2.3.conv3', 'backbone.layer3.5.conv3', 'backbone.layer4.2.conv3'] elif args.arch == 'resnet101': return_layers = ['backbone.layer1.2.conv3', 'backbone.layer2.3.conv3', 'backbone.layer3.5.conv3', 'backbone.layer4.2.conv3'] else: raise NotImplementedError(args.arch) source_getter = IntermediateLayerGetter(source_classifier, return_layers=return_layers) target_getter = IntermediateLayerGetter(classifier, return_layers=return_layers) # get regularization if args.regularization_type == 'l2_sp': backbone_regularization = SPRegularization(source_classifier.backbone, classifier.backbone) elif args.regularization_type == 'feature_map': backbone_regularization = BehavioralRegularization() elif args.regularization_type == 'attention_feature_map': attention_file = os.path.join(logger.root, args.attention_file) if not os.path.exists(attention_file): attention = calculate_channel_attention(train_dataset, return_layers, num_classes, args) torch.save(attention, attention_file) else: print("Loading channel attention from", attention_file) attention = torch.load(attention_file) attention = [a.to(device) for a in attention] backbone_regularization = 
AttentionBehavioralRegularization(attention) else: raise NotImplementedError(args.regularization_type) head_regularization = L2Regularization(nn.ModuleList([classifier.head, classifier.bottleneck])) # start training best_acc1 = 0.0 for epoch in range(args.epochs): print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, classifier, backbone_regularization, head_regularization, target_getter, source_getter, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def calculate_channel_attention(dataset, return_layers, num_classes, args): backbone = utils.get_model(args.arch) classifier = Classifier(backbone, num_classes).to(device) optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True) data_loader = DataLoader(dataset, batch_size=args.attention_batch_size, shuffle=True, num_workers=args.workers, drop_last=False) lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=math.exp( math.log(0.1) / args.attention_lr_decay_epochs)) criterion = nn.CrossEntropyLoss() channel_weights = [] for layer_id, name in enumerate(return_layers): layer = get_attribute(classifier, name) layer_channel_weight = [0] * layer.out_channels channel_weights.append(layer_channel_weight) # train the classifier classifier.train() classifier.backbone.requires_grad = False print("Pretrain a classifier to calculate channel attention.") for epoch in range(args.attention_epochs): losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( len(data_loader), [losses, cls_accs], 
prefix="Epoch: [{}]".format(epoch)) for i, data in enumerate(data_loader): inputs, labels = data inputs = inputs.to(device) labels = labels.to(device) outputs, _ = classifier(inputs) loss = criterion(outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() cls_acc = accuracy(outputs, labels)[0] losses.update(loss.item(), inputs.size(0)) cls_accs.update(cls_acc.item(), inputs.size(0)) if i % args.print_freq == 0: progress.display(i) lr_scheduler.step() # calculate the channel attention print('Calculating channel attention.') classifier.eval() if args.attention_iteration_limit > 0: total_iteration = min(len(data_loader), args.attention_iteration_limit) else: total_iteration = len(args.data_loader) progress = ProgressMeter( total_iteration, [], prefix="Iteration: ") for i, data in enumerate(data_loader): if i >= total_iteration: break inputs, labels = data inputs = inputs.to(device) labels = labels.to(device) outputs = classifier(inputs) loss_0 = criterion(outputs, labels) progress.display(i) for layer_id, name in enumerate(tqdm(return_layers)): layer = get_attribute(classifier, name) for j in range(layer.out_channels): tmp = classifier.state_dict()[name + '.weight'][j,].clone() classifier.state_dict()[name + '.weight'][j,] = 0.0 outputs = classifier(inputs) loss_1 = criterion(outputs, labels) difference = loss_1 - loss_0 difference = difference.detach().cpu().numpy().item() history_value = channel_weights[layer_id][j] channel_weights[layer_id][j] = 1.0 * (i * history_value + difference) / (i + 1) classifier.state_dict()[name + '.weight'][j,] = tmp channel_attention = [] for weight in channel_weights: weight = np.array(weight) weight = (weight - np.mean(weight)) / np.std(weight) weight = torch.from_numpy(weight).float().to(device) channel_attention.append(F.softmax(weight / 5).detach()) return channel_attention def train(train_iter: ForeverDataIterator, model: Classifier, backbone_regularization: nn.Module, head_regularization: nn.Module, target_getter: 
IntermediateLayerGetter, source_getter: IntermediateLayerGetter, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') losses_reg_head = AverageMeter('Loss (reg, head)', ':3.2f') losses_reg_backbone = AverageMeter('Loss (reg, backbone)', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, losses_reg_head, losses_reg_backbone, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, labels = next(train_iter) x = x.to(device) label = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output intermediate_output_s, output_s = source_getter(x) intermediate_output_t, output_t = target_getter(x) y, f = output_t # measure accuracy and record loss cls_acc = accuracy(y, label)[0] cls_loss = F.cross_entropy(y, label) if args.regularization_type == 'feature_map': loss_reg_backbone = backbone_regularization(intermediate_output_s, intermediate_output_t) elif args.regularization_type == 'attention_feature_map': loss_reg_backbone = backbone_regularization(intermediate_output_s, intermediate_output_t) else: loss_reg_backbone = backbone_regularization() loss_reg_head = head_regularization() loss = cls_loss + args.trade_off_backbone * loss_reg_backbone + args.trade_off_head * loss_reg_head losses_reg_backbone.update(loss_reg_backbone.item() * args.trade_off_backbone, x.size(0)) losses_reg_head.update(loss_reg_head.item() * args.trade_off_head, x.size(0)) losses.update(loss.item(), x.size(0)) cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if 
__name__ == '__main__': parser = argparse.ArgumentParser(description='Delta for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") parser.add_argument('--regularization-type', choices=['l2_sp', 'feature_map', 'attention_feature_map'], default='attention_feature_map') parser.add_argument('--trade-off-backbone', default=0.01, type=float, help='trade-off for backbone regularization') parser.add_argument('--trade-off-head', default=0.01, type=float, help='trade-off for head regularization') # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0., type=float, metavar='W', help='weight decay (default: 0.)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ') parser.add_argument("--log", type=str, default='delta', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model." 
"When phase is 'analysis', only analysis the model.") # parameters for calculating channel attention parser.add_argument("--attention-file", type=str, default='channel_attention.pt', help="Where to save and load channel attention file.") parser.add_argument('--attention-batch-size', default=32, type=int, metavar='N', help='mini-batch size for calculating channel attention (default: 32)') parser.add_argument('--attention-epochs', default=10, type=int, metavar='N', help='number of epochs to train for training before calculating channel weight') parser.add_argument('--attention-lr-decay-epochs', default=6, type=int, metavar='N', help='epochs to decay lr for training before calculating channel weight') parser.add_argument('--attention-iteration-limit', default=10, type=int, metavar='N', help='iteration limits for calculating channel attention, -1 means no limits') args = parser.parse_args() main(args) ================================================ FILE: examples/task_adaptation/image_classification/delta.sh ================================================ #!/usr/bin/env bash # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/delta/cub200_100 CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/delta/cub200_50 CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/delta/cub200_30 CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/delta/cub200_15 # Stanford Cars CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/delta/car_100 CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/delta/car_50 CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/delta/car_30 CUDA_VISIBLE_DEVICES=0 python delta.py 
data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/delta/car_15

# Aircrafts
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/delta/aircraft_100
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/delta/aircraft_50
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/delta/aircraft_30
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/delta/aircraft_15

# CIFAR10
CUDA_VISIBLE_DEVICES=0 python delta.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/delta/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python delta.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/delta/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/delta/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python delta.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/delta/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python delta.py data/dtd -d DTD --seed 0 --finetune --log logs/delta/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python delta.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/delta/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python delta.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/delta/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python delta.py data/food-101 -d Food101 --seed 0 --finetune --log logs/delta/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/delta/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft
CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/delta/aircraft/lr_1e-2 --lr 1e-2

# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cub200_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cub200_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cub200_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cub200_15 --pretrained checkpoints/moco_v1_200ep_pretrain.pth

# Stanford Cars
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cars_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cars_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cars_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth
CUDA_VISIBLE_DEVICES=0 python delta.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_delta/cars_15 --pretrained
checkpoints/moco_v1_200ep_pretrain.pth # Aircrafts CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_delta/aircraft_100 --pretrained checkpoints/moco_v1_200ep_pretrain.pth CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_delta/aircraft_50 --pretrained checkpoints/moco_v1_200ep_pretrain.pth CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_delta/aircraft_30 --pretrained checkpoints/moco_v1_200ep_pretrain.pth CUDA_VISIBLE_DEVICES=0 python delta.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_delta/aircraft_15 --pretrained checkpoints/moco_v1_200ep_pretrain.pth ================================================ FILE: examples/task_adaptation/image_classification/erm.py ================================================ """ @author: Yifei Ji, Junguang Jiang @contact: jiyf990330@163.com, JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.modules.classifier import Classifier from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) 
torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, finetune=args.finetune).to(device) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0.0 for epoch in range(args.epochs): 
logger.set_epoch(epoch) print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, classifier, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, labels = next(train_iter) x = x.to(device) label = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, f = model(x) cls_loss = F.cross_entropy(y, label) loss = cls_loss # measure accuracy and record loss losses.update(loss.item(), x.size(0)) cls_acc = accuracy(y, label)[0] cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Baseline for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of 
training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--optimizer', type=str, default='SGD', choices=['SGD', 'Adam']) parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
')
    parser.add_argument("--log", type=str, default='baseline',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/task_adaptation/image_classification/erm.sh
================================================
#!/usr/bin/env bash
# Supervised Pretraining
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/erm/cub200_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/erm/cub200_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/erm/cub200_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/erm/cub200_15
# Stanford Cars (typo fixed: was "Standford")
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/erm/car_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/erm/car_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/erm/car_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/erm/car_15
# Aircrafts
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/erm/aircraft_100
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/erm/aircraft_50
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/erm/aircraft_30
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/erm/aircraft_15
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/erm/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python erm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/erm/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/erm/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python erm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/erm/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python erm.py data/dtd -d DTD --seed 0 --finetune --log logs/erm/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python erm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/erm/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python erm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/erm/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python erm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/erm/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars (typo fixed: was "Standford")
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/erm/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft (header fixed: was mislabeled "Standford Cars")
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/erm/aircraft/lr_1e-2 --lr 1e-2

# MoCo (Unsupervised Pretraining)
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Stanford Cars (typo fixed: was "Standford")
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth
# Aircrafts
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth
CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \
  --log logs/moco_pretrain_erm/aircraft_30 --pretrained \
checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python erm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_erm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth ================================================ FILE: examples/task_adaptation/image_classification/lwf.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import random import time import warnings import argparse import shutil import torch import torch.nn as nn import torch.backends.cudnn as cudnn from torch.optim import SGD from torch.utils.data import DataLoader, TensorDataset import torch.nn.functional as F import utils from tllib.regularization.lwf import collect_pretrain_labels, Classifier from tllib.regularization.knowledge_distillation import KnowledgeDistillationLoss from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.logger import CompleteLogger from tllib.utils.data import ForeverDataIterator, CombineDataset device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! 
' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers, drop_last=False) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier = Classifier(backbone, num_classes, head_source=backbone.copy_head(), pool_layer=pool_layer, finetune=args.finetune).to(device) kd = KnowledgeDistillationLoss(args.T) source_classifier = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.head_source) pretrain_labels = collect_pretrain_labels(train_loader, source_classifier, device) train_dataset = CombineDataset([train_dataset, TensorDataset(pretrain_labels)]) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(args.lr), momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = 
torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0.0 for epoch in range(args.epochs): # train for one epoch train(train_iter, classifier, kd, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model: Classifier, kd, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') losses_kd = AverageMeter('Loss (KD)', ':5.4f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, losses_kd, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, label_t, label_s = next(train_iter) x = x.to(device) label_s = label_s.to(device) label_t = label_t.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y_s, y_t = model(x) tgt_loss = F.cross_entropy(y_t, label_t) src_loss = kd(y_s, label_s) loss = tgt_loss + args.trade_off * src_loss # measure accuracy and record loss losses.update(tgt_loss.item(), x.size(0)) losses_kd.update(src_loss.item(), x.size(0)) cls_acc = accuracy(y_t, label_t)[0] cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - 
end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='LWF (Learning without Forgetting) for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--trade-off', default=4, type=float, metavar='P', help='weight of pretrained loss') parser.add_argument("-T", type=float, default=3, help="temperature for knowledge distillation") parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
"
                             "(default: None, use the ImageNet supervised pretrained backbone)")
    # training parameters
    parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N',
                        help='mini-batch size (default: 48)')
    parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR',
                        help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W',
                        help='weight decay (default: 5e-4)')
    parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                        help='number of data loading workers (default: 2)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N',
                        help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. ')
    parser.add_argument("--log", type=str, default='lwf',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'],
                        help="When phase is 'test', only test the model.")
    args = parser.parse_args()
    main(args)


================================================
FILE: examples/task_adaptation/image_classification/lwf.sh
================================================
#!/usr/bin/env bash
# CUB-200-2011
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/lwf/cub200_100 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/lwf/cub200_50 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/lwf/cub200_30 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/lwf/cub200_15 --lr 0.001
# Stanford Cars (typo fixed: was "Standford")
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/lwf/car_100 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/lwf/car_50 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/lwf/car_30 --lr 0.01
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/lwf/car_15 --lr 0.01
# Aircrafts
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/lwf/aircraft_100 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/lwf/aircraft_50 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/lwf/aircraft_30 --lr 0.001
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft -sr 15 --seed 0 --finetune --log logs/lwf/aircraft_15 --lr 0.001
# CIFAR10
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/lwf/cifar10/1e-2 --lr 1e-2
# CIFAR100
CUDA_VISIBLE_DEVICES=0 python lwf.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/lwf/cifar100/1e-2 --lr 1e-2
# Flowers
CUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/lwf/oxford_flowers102/1e-2 --lr 1e-2
# Pets
CUDA_VISIBLE_DEVICES=0 python lwf.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/lwf/oxford_pet/1e-2 --lr 1e-2
# DTD
CUDA_VISIBLE_DEVICES=0 python lwf.py data/dtd -d DTD --seed 0 --finetune --log logs/lwf/dtd/1e-2 --lr 1e-2
# caltech101
CUDA_VISIBLE_DEVICES=0 python lwf.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/lwf/caltech101/lr_1e-3 --lr 1e-3
# SUN397
CUDA_VISIBLE_DEVICES=0 python lwf.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/lwf/sun397/lr_1e-2 --lr 1e-2
# Food 101
CUDA_VISIBLE_DEVICES=0 python lwf.py data/food-101 -d Food101 --seed 0 --finetune --log logs/lwf/food-101/lr_1e-2 --lr 1e-2
# Stanford Cars (typo fixed: was "Standford")
CUDA_VISIBLE_DEVICES=0 python lwf.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/lwf/stanford_cars/lr_1e-2 --lr 1e-2
# Aircraft (header fixed: was mislabeled "Standford Cars")
CUDA_VISIBLE_DEVICES=0 python lwf.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/lwf/aircraft/lr_1e-2 --lr 1e-2


================================================
FILE: examples/task_adaptation/image_classification/requirements.txt
================================================
timm


================================================
FILE: examples/task_adaptation/image_classification/stochnorm.py
================================================
"""
@author: Yifei Ji, Junguang Jiang
@contact: jiyf990330@163.com, JiangJunguang1123@outlook.com
"""
import random
import time
import warnings
import argparse
import shutil

import torch
import torch.nn as nn
import torch.backends.cudnn
as cudnn from torch.optim import SGD from torch.utils.data import DataLoader import torch.nn.functional as F import utils from tllib.normalization.stochnorm import convert_model from tllib.modules.classifier import Classifier from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.utils.data import ForeverDataIterator from tllib.utils.logger import CompleteLogger device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def main(args: argparse.Namespace): logger = CompleteLogger(args.log, args.phase) print(args) if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) cudnn.deterministic = True warnings.warn('You have chosen to seed training. ' 'This will turn on the CUDNN deterministic setting, ' 'which can slow down your training considerably! ' 'You may see unexpected behavior when restarting ' 'from checkpoints.') cudnn.benchmark = True # Data loading code train_transform = utils.get_train_transform(args.train_resizing, not args.no_hflip, args.color_jitter) val_transform = utils.get_val_transform(args.val_resizing) print("train_transform: ", train_transform) print("val_transform: ", val_transform) train_dataset, val_dataset, num_classes = utils.get_dataset(args.data, args.root, train_transform, val_transform, args.sample_rate, args.num_samples_per_classes) train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, drop_last=True) train_iter = ForeverDataIterator(train_loader) val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers) print("training dataset size: {} test dataset size: {}".format(len(train_dataset), len(val_dataset))) # create model print("=> using pre-trained model '{}'".format(args.arch)) backbone = utils.get_model(args.arch, args.pretrained) pool_layer = nn.Identity() if args.no_pool else None classifier = Classifier(backbone, num_classes, pool_layer=pool_layer, 
finetune=args.finetune).to(device) classifier = convert_model(classifier, p=args.prob) # define optimizer and lr scheduler optimizer = SGD(classifier.get_parameters(args.lr), lr=args.lr, momentum=args.momentum, weight_decay=args.wd, nesterov=True) lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_decay_epochs, gamma=args.lr_gamma) # resume from the best checkpoint if args.phase == 'test': checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu') classifier.load_state_dict(checkpoint) acc1 = utils.validate(val_loader, classifier, args, device) print(acc1) return # start training best_acc1 = 0.0 for epoch in range(args.epochs): print(lr_scheduler.get_lr()) # train for one epoch train(train_iter, classifier, optimizer, epoch, args) lr_scheduler.step() # evaluate on validation set acc1 = utils.validate(val_loader, classifier, args, device) # remember best acc@1 and save checkpoint torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest')) if acc1 > best_acc1: shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best')) best_acc1 = max(acc1, best_acc1) print("best_acc1 = {:3.1f}".format(best_acc1)) logger.close() def train(train_iter: ForeverDataIterator, model: Classifier, optimizer: SGD, epoch: int, args: argparse.Namespace): batch_time = AverageMeter('Time', ':4.2f') data_time = AverageMeter('Data', ':3.1f') losses = AverageMeter('Loss', ':3.2f') cls_accs = AverageMeter('Cls Acc', ':3.1f') progress = ProgressMeter( args.iters_per_epoch, [batch_time, data_time, losses, cls_accs], prefix="Epoch: [{}]".format(epoch)) # switch to train mode model.train() end = time.time() for i in range(args.iters_per_epoch): x, labels = next(train_iter) x = x.to(device) label = labels.to(device) # measure data loading time data_time.update(time.time() - end) # compute output y, f = model(x) cls_loss = F.cross_entropy(y, label) loss = cls_loss cls_acc = accuracy(y, label)[0] losses.update(loss.item(), 
x.size(0)) cls_accs.update(cls_acc.item(), x.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if __name__ == '__main__': parser = argparse.ArgumentParser(description='StochNorm for Finetuning') # dataset parameters parser.add_argument('root', metavar='DIR', help='root path of dataset') parser.add_argument('-d', '--data', metavar='DATA') parser.add_argument('-sr', '--sample-rate', default=100, type=int, metavar='N', help='sample rate of training dataset (default: 100)') parser.add_argument('-sc', '--num-samples-per-classes', default=None, type=int, help='number of samples per classes.') parser.add_argument('--train-resizing', type=str, default='default', help='resize mode during training') parser.add_argument('--val-resizing', type=str, default='default', help='resize mode during validation') parser.add_argument('--no-hflip', action='store_true', help='no random horizontal flipping during training') parser.add_argument('--color-jitter', action='store_true', help='apply jitter during training') # model parameters parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', choices=utils.get_model_names(), help='backbone architecture: ' + ' | '.join(utils.get_model_names()) + ' (default: resnet50)') parser.add_argument('--no-pool', action='store_true', help='no pool layer after the feature extractor. Used in models such as ViT.') parser.add_argument('--finetune', action='store_true', help='whether use 10x smaller lr for backbone') parser.add_argument('--prob', '--probability', default=0.5, type=float, metavar='P', help='Probability for StochNorm layers') parser.add_argument('--pretrained', default=None, help="pretrained checkpoint of the backbone. 
" "(default: None, use the ImageNet supervised pretrained backbone)") # training parameters parser.add_argument('-b', '--batch-size', default=48, type=int, metavar='N', help='mini-batch size (default: 48)') parser.add_argument('--lr', '--learning-rate', default=0.01, type=float, metavar='LR', help='initial learning rate', dest='lr') parser.add_argument('--lr-gamma', default=0.1, type=float, help='parameter for lr scheduler') parser.add_argument('--lr-decay-epochs', type=int, default=(12,), nargs='+', help='epochs to decay lr') parser.add_argument('--momentum', default=0.9, type=float, metavar='M', help='momentum') parser.add_argument('--wd', '--weight-decay', default=0.0005, type=float, metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('--epochs', default=20, type=int, metavar='N', help='number of total epochs to run') parser.add_argument('-i', '--iters-per-epoch', default=500, type=int, help='Number of iterations per epoch') parser.add_argument('-p', '--print-freq', default=100, type=int, metavar='N', help='print frequency (default: 100)') parser.add_argument('--seed', default=None, type=int, help='seed for initializing training. 
') parser.add_argument("--log", type=str, default='stochnorm', help="Where to save logs, checkpoints and debugging images.") parser.add_argument("--phase", type=str, default='train', choices=['train', 'test'], help="When phase is 'test', only test the model.") args = parser.parse_args() main(args) ================================================ FILE: examples/task_adaptation/image_classification/stochnorm.sh ================================================ #!/usr/bin/env bash # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --finetune --log logs/stochnorm/cub200_100 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --finetune --log logs/stochnorm/cub200_50 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --finetune --log logs/stochnorm/cub200_30 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --finetune --log logs/stochnorm/cub200_15 # Standford Cars CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --finetune --log logs/stochnorm/car_100 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --finetune --log logs/stochnorm/car_50 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --finetune --log logs/stochnorm/car_30 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --finetune --log logs/stochnorm/car_15 # Aircrafts CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --finetune --log logs/stochnorm/aircraft_100 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --finetune --log logs/stochnorm/aircraft_50 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --finetune --log logs/stochnorm/aircraft_30 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft 
-d Aircraft -sr 15 --seed 0 --finetune --log logs/stochnorm/aircraft_15 # CIFAR10 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar10 -d CIFAR10 --seed 0 --finetune --log logs/stochnorm/cifar10/1e-2 --lr 1e-2 # CIFAR100 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cifar100 -d CIFAR100 --seed 0 --finetune --log logs/stochnorm/cifar100/1e-2 --lr 1e-2 # Flowers CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_flowers102 -d OxfordFlowers102 --seed 0 --finetune --log logs/stochnorm/oxford_flowers102/1e-2 --lr 1e-2 # Pets CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/oxford_pet -d OxfordIIITPets --seed 0 --finetune --log logs/stochnorm/oxford_pet/1e-2 --lr 1e-2 # DTD CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/dtd -d DTD --seed 0 --finetune --log logs/stochnorm/dtd/1e-2 --lr 1e-2 # caltech101 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/caltech101 -d Caltech101 --seed 0 --finetune --log logs/stochnorm/caltech101/lr_1e-3 --lr 1e-3 # SUN397 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/sun397 -d SUN397 --seed 0 --finetune --log logs/stochnorm/sun397/lr_1e-2 --lr 1e-2 # Food 101 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/food-101 -d Food101 --seed 0 --finetune --log logs/stochnorm/food-101/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars --seed 0 --finetune --log logs/stochnorm/stanford_cars/lr_1e-2 --lr 1e-2 # Standford Cars CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft --seed 0 --finetune --log logs/stochnorm/aircraft/lr_1e-2 --lr 1e-2 # MoCo (Unsupervised Pretraining) # CUB-200-2011 CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cub200_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 
--epochs 12 \ --log logs/moco_pretrain_stochnorm/cub200_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cub200_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/cub200 -d CUB200 -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cub200_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Standford Cars CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cars_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 50 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cars_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cars_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/stanford_cars -d StanfordCars -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/cars_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth # Aircrafts CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 100 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/aircraft_100 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 50 --seed 0 --lr 0.1 
--finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/aircraft_50 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 30 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/aircraft_30 --pretrained checkpoints/moco_v1_200ep_backbone.pth CUDA_VISIBLE_DEVICES=0 python stochnorm.py data/aircraft -d Aircraft -sr 15 --seed 0 --lr 0.1 --finetune -i 2000 --lr-decay-epochs 3 6 9 --epochs 12 \ --log logs/moco_pretrain_stochnorm/aircraft_15 --pretrained checkpoints/moco_v1_200ep_backbone.pth ================================================ FILE: examples/task_adaptation/image_classification/utils.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import time from PIL import Image import timm import numpy as np import random import sys import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Subset import torchvision.transforms as T from torch.optim import SGD, Adam sys.path.append('../../..') import tllib.vision.datasets as datasets import tllib.vision.models as models from tllib.utils.metric import accuracy from tllib.utils.meter import AverageMeter, ProgressMeter from tllib.vision.transforms import Denormalize def get_model_names(): return sorted( name for name in models.__dict__ if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) ) + timm.list_models() def get_model(model_name, pretrained_checkpoint=None): if model_name in models.__dict__: # load models from tllib.vision.models backbone = models.__dict__[model_name](pretrained=True) else: # load models from pytorch-image-models backbone = timm.create_model(model_name, pretrained=True) try: backbone.out_features = backbone.get_classifier().in_features backbone.reset_classifier(0, '') backbone.copy_head = 
backbone.get_classifier except: backbone.out_features = backbone.head.in_features backbone.head = nn.Identity() backbone.copy_head = lambda x: x.head if pretrained_checkpoint: print("=> loading pre-trained model from '{}'".format(pretrained_checkpoint)) pretrained_dict = torch.load(pretrained_checkpoint) backbone.load_state_dict(pretrained_dict, strict=False) return backbone def get_dataset(dataset_name, root, train_transform, val_transform, sample_rate=100, num_samples_per_classes=None): """ When sample_rate < 100, e.g. sample_rate = 50, use 50% data to train the model. Otherwise, if num_samples_per_classes is not None, e.g. 5, then sample 5 images for each class, and use them to train the model; otherwise, keep all the data. """ dataset = datasets.__dict__[dataset_name] if sample_rate < 100: train_dataset = dataset(root=root, split='train', sample_rate=sample_rate, download=True, transform=train_transform) test_dataset = dataset(root=root, split='test', sample_rate=100, download=True, transform=val_transform) num_classes = train_dataset.num_classes else: train_dataset = dataset(root=root, split='train', download=True, transform=train_transform) test_dataset = dataset(root=root, split='test', download=True, transform=val_transform) num_classes = train_dataset.num_classes if num_samples_per_classes is not None: samples = list(range(len(train_dataset))) random.shuffle(samples) samples_len = min(num_samples_per_classes * num_classes, len(train_dataset)) print("Origin dataset:", len(train_dataset), "Sampled dataset:", samples_len, "Ratio:", float(samples_len) / len(train_dataset)) train_dataset = Subset(train_dataset, samples[:samples_len]) return train_dataset, test_dataset, num_classes def validate(val_loader, model, args, device, visualize=None) -> float: batch_time = AverageMeter('Time', ':6.3f') losses = AverageMeter('Loss', ':.4e') top1 = AverageMeter('Acc@1', ':6.2f') progress = ProgressMeter( len(val_loader), [batch_time, losses, top1], prefix='Test: ') # 
switch to evaluate mode model.eval() with torch.no_grad(): end = time.time() for i, (images, target) in enumerate(val_loader): images = images.to(device) target = target.to(device) # compute output output = model(images) loss = F.cross_entropy(output, target) # measure accuracy and record loss acc1, = accuracy(output, target, topk=(1, )) losses.update(loss.item(), images.size(0)) top1.update(acc1.item(), images.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: progress.display(i) if visualize is not None: visualize(images[0], "val_{}".format(i)) print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1)) return top1.avg def get_train_transform(resizing='default', random_horizontal_flip=True, random_color_jitter=False): """ resizing mode: - default: take a random resized crop of size 224 with scale in [0.2, 1.]; - res: resize the image to 224; - res.|crop: resize the image to 256 and take a random crop of size 224; - res.sma|crop: resize the image keeping its aspect ratio such that the smaller side is 256, then take a random crop of size 224; – inc.crop: “inception crop” from (Szegedy et al., 2015); – cif.crop: resize the image to 224, zero-pad it by 28 on each side, then take a random crop of size 224. 
    """
    if resizing == 'default':
        transform = T.RandomResizedCrop(224, scale=(0.2, 1.))
    elif resizing == 'res.':
        transform = T.Resize((224, 224))
    elif resizing == 'res.|crop':
        transform = T.Compose([
            T.Resize((256, 256)),
            T.RandomCrop(224)
        ])
    elif resizing == "res.sma|crop":
        transform = T.Compose([
            T.Resize(256),
            T.RandomCrop(224)
        ])
    elif resizing == 'inc.crop':
        transform = T.RandomResizedCrop(224)
    elif resizing == 'cif.crop':
        transform = T.Compose([
            T.Resize((224, 224)),
            T.Pad(28),
            T.RandomCrop(224),
        ])
    else:
        raise NotImplementedError(resizing)
    # optional augmentations are appended after the geometric transform
    transforms = [transform]
    if random_horizontal_flip:
        transforms.append(T.RandomHorizontalFlip())
    if random_color_jitter:
        transforms.append(T.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5))
    # ImageNet channel statistics for normalization
    transforms.extend([
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    return T.Compose(transforms)


def get_val_transform(resizing='default'):
    """
    resizing mode:
        - default: resize the image to 256 and take the center crop of size 224;
        – res.: resize the image to 224
        – res.|crop: resize the image such that the smaller side is of size 256 and
            then take a central crop of size 224.
""" if resizing == 'default': transform = T.Compose([ T.Resize(256), T.CenterCrop(224), ]) elif resizing == 'res.': transform = T.Resize((224, 224)) elif resizing == 'res.|crop': transform = T.Compose([ T.Resize((256, 256)), T.CenterCrop(224), ]) else: raise NotImplementedError(resizing) return T.Compose([ transform, T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def get_optimizer(optimizer_name, params, lr, wd, momentum): ''' Args: optimizer_name: - SGD - Adam params: iterable of parameters to optimize or dicts defining parameter groups lr: learning rate weight_decay: weight decay momentum: momentum factor for SGD ''' if optimizer_name == 'SGD': optimizer = SGD(params=params, lr=lr, momentum=momentum, weight_decay=wd, nesterov=True) elif optimizer_name == 'Adam': optimizer = Adam(params=params, lr=lr, weight_decay=wd) else: raise NotImplementedError(optimizer_name) return optimizer def visualize(image, filename): """ Args: image (tensor): 3 x H x W filename: filename of the saving image """ image = image.detach().cpu() image = Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image) image = image.numpy().transpose((1, 2, 0)) * 255 Image.fromarray(np.uint8(image)).save(filename) ================================================ FILE: requirements.txt ================================================ torch>=1.7.0 torchvision>=0.5.0 numpy prettytable tqdm scikit-learn webcolors matplotlib opencv-python numba ================================================ FILE: setup.py ================================================ from setuptools import setup, find_packages import re from os import path here = path.abspath(path.dirname(__file__)) # Get the version string with open(path.join(here, 'tllib', '__init__.py')) as f: version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) # Get all runtime requirements REQUIRES = [] with open('requirements.txt') as f: for line in f: line, _, _ = line.partition('#') line = 
line.strip() REQUIRES.append(line) if __name__ == '__main__': setup( name="tllib", # Replace with your own username version=version, author="THUML", author_email="JiangJunguang1123@outlook.com", keywords="domain adaptation, task adaptation, domain generalization, " "transfer learning, deep learning, pytorch", description="A Transfer Learning Library for Domain Adaptation, Task Adaptation, and Domain Generalization", long_description=open('README.md', encoding='utf8').read(), long_description_content_type="text/markdown", url="https://github.com/thuml/Transfer-Learning-Library", packages=find_packages(exclude=['docs', 'examples']), classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable 'Development Status :: 3 - Alpha', # Indicate who your project is intended for 'Intended Audience :: Science/Research', 'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Software Development :: Libraries :: Python Modules', # Pick your license as you wish (should match "license" above) 'License :: OSI Approved :: MIT License', # Specify the Python versions you support here. In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', ], python_requires='>=3.6', install_requires=REQUIRES, extras_require={ 'dev': [ 'Sphinx', 'sphinx_rtd_theme', ] }, ) ================================================ FILE: tllib/__init__.py ================================================ from . import alignment from . import self_training from . import translation from . import regularization from . import utils from . import vision from . import modules from . 
import ranking __version__ = '0.4' __all__ = ['alignment', 'self_training', 'translation', 'regularization', 'utils', 'vision', 'modules', 'ranking'] ================================================ FILE: tllib/alignment/__init__.py ================================================ from . import cdan from . import dann from . import mdd from . import dan from . import jan from . import mcd from . import osbp from . import adda from . import bsp ================================================ FILE: tllib/alignment/adda.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from typing import Optional, List, Dict import torch import torch.nn as nn import torch.nn.functional as F from tllib.modules.classifier import Classifier as ClassifierBase class DomainAdversarialLoss(nn.Module): r"""Domain adversarial loss from `Adversarial Discriminative Domain Adaptation (CVPR 2017) `_. Similar to the original `GAN `_ paper, ADDA argues that replacing :math:`\text{log}(1-p)` with :math:`-\text{log}(p)` in the adversarial loss provides better gradient qualities. Detailed optimization process can be found `here `_. Inputs: - domain_pred (tensor): predictions of domain discriminator - domain_label (str, optional): whether the data comes from source or target. Must be 'source' or 'target'. Default: 'source' Shape: - domain_pred: :math:`(minibatch,)`. - Outputs: scalar. 
    """

    def __init__(self):
        super(DomainAdversarialLoss, self).__init__()

    def forward(self, domain_pred, domain_label='source'):
        assert domain_label in ['source', 'target']
        # source examples are labeled 1, target examples 0 (binary cross entropy on
        # already-sigmoided discriminator outputs)
        if domain_label == 'source':
            # NOTE(review): .to(domain_pred.device) is redundant — ones_like/zeros_like
            # already allocate on the input's device
            return F.binary_cross_entropy(domain_pred, torch.ones_like(domain_pred).to(domain_pred.device))
        else:
            return F.binary_cross_entropy(domain_pred, torch.zeros_like(domain_pred).to(domain_pred.device))


class ImageClassifier(ClassifierBase):
    """Classifier with a Linear -> BatchNorm1d -> ReLU bottleneck projecting backbone
    features down to ``bottleneck_dim`` before the classification head."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten(),
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def freeze_bn(self):
        # put every BatchNorm layer in eval mode so running statistics stay fixed
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
                m.eval()

    def get_parameters(self, base_lr=1.0, optimize_head=True) -> List[Dict]:
        """Return per-group learning rates: 0.1x base_lr for a finetuned backbone,
        1x for bottleneck, and (optionally) 1x for the head."""
        params = [
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}
        ]
        if optimize_head:
            params.append({"params": self.head.parameters(), "lr": 1.0 * base_lr})
        return params


================================================
FILE: tllib/alignment/advent.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from torch import nn
import torch
import torch.nn.functional as F
import numpy as np


class Discriminator(nn.Sequential):
    """
    Domain discriminator model from
    `ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation (CVPR 2019) `_

    Distinguish pixel-by-pixel whether the input predictions come from the source domain
    or the target domain. The source domain label is 1 and the target domain label is 0.
    Args:
        num_classes (int): num of classes in the predictions
        ndf (int): dimension of the hidden features

    Shape:
        - Inputs: :math:`(minibatch, C, H, W)` where :math:`C` is the number of classes
        - Outputs: :math:`(minibatch, 1, H, W)`
    """

    def __init__(self, num_classes, ndf=64):
        # five stride-2 4x4 convolutions with LeakyReLU(0.2); channel width doubles each
        # stage (ndf -> 8*ndf), and the last conv maps to a 1-channel domain score map
        super(Discriminator, self).__init__(
            nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
        )


def prob_2_entropy(prob):
    """ convert probabilistic prediction maps to weighted self-information maps """
    # -p * log2(p + eps), divided by log2(C) so the per-pixel value is normalized by
    # the maximum possible entropy; NOTE(review): n, h, w are unpacked but unused
    n, c, h, w = prob.size()
    return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)


def bce_loss(y_pred, y_label):
    """Binary cross entropy (with logits) against a constant map filled with y_label."""
    y_truth_tensor = torch.FloatTensor(y_pred.size())
    y_truth_tensor.fill_(y_label)
    # NOTE(review): Tensor.get_device() is only defined for CUDA tensors — this breaks
    # on CPU inputs; `.to(y_pred.device)` would be the safe form. Confirm callers.
    y_truth_tensor = y_truth_tensor.to(y_pred.get_device())
    return F.binary_cross_entropy_with_logits(y_pred, y_truth_tensor)


class DomainAdversarialEntropyLoss(nn.Module):
    r"""The `Domain Adversarial Entropy Loss `_

    Minimizing entropy with adversarial learning through training a domain discriminator.

    Args:
        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts
            the domains of predictions. Its input shape is :math:`(minibatch, C, H, W)` and
            output shape is :math:`(minibatch, 1, H, W)`

    Inputs:
        - logits (tensor): logits output of segmentation model
        - domain_label (str, optional): whether the data comes from source or target.
          Choices: ['source', 'target']. Default: 'source'

    Shape:
        - logits: :math:`(minibatch, C, H, W)` where :math:`C` means the number of classes
        - Outputs: scalar.
    Examples::

        >>> B, C, H, W = 2, 19, 512, 512
        >>> discriminator = Discriminator(num_classes=C)
        >>> dann = DomainAdversarialEntropyLoss(discriminator)
        >>> # logits output on source domain and target domain
        >>> y_s, y_t = torch.randn(B, C, H, W), torch.randn(B, C, H, W)
        >>> loss = 0.5 * (dann(y_s, "source") + dann(y_t, "target"))
    """

    def __init__(self, discriminator: nn.Module):
        super(DomainAdversarialEntropyLoss, self).__init__()
        self.discriminator = discriminator

    def forward(self, logits, domain_label='source'):
        """Classify the weighted self-information (entropy) map of `logits` with the
        discriminator and return the BCE loss for the requested domain label."""
        assert domain_label in ['source', 'target']
        probability = F.softmax(logits, dim=1)
        entropy = prob_2_entropy(probability)
        # NOTE(review): 'domain_prediciton' is a typo for 'domain_prediction' (local name only)
        domain_prediciton = self.discriminator(entropy)
        # source pixels are labeled 1, target pixels 0
        if domain_label == 'source':
            return bce_loss(domain_prediciton, 1)
        else:
            return bce_loss(domain_prediciton, 0)

    def train(self, mode=True):
        r"""Sets the discriminator in training mode. In the training mode,
        all the parameters in discriminator will be set requires_grad=True.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Default: ``True``.
        """
        self.discriminator.train(mode)
        for param in self.discriminator.parameters():
            param.requires_grad = mode
        return self

    def eval(self):
        r"""Sets the module in evaluation mode. In the training mode,
        all the parameters in discriminator will be set requires_grad=False.
        This is equivalent with :meth:`self.train(False) `.
        """
        return self.train(False)


================================================
FILE: tllib/alignment/bsp.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional
import torch
import torch.nn as nn

from tllib.modules.classifier import Classifier as ClassifierBase


class BatchSpectralPenalizationLoss(nn.Module):
    r"""Batch spectral penalization loss from `Transferability vs. Discriminability: Batch
    Spectral Penalization for Adversarial Domain Adaptation (ICML 2019) `_.
    Given source features :math:`f_s` and target features :math:`f_t` in current mini batch, singular value
    decomposition is first performed

    .. math::
        f_s = U_s\Sigma_sV_s^T

    .. math::
        f_t = U_t\Sigma_tV_t^T

    Then batch spectral penalization loss is calculated as

    .. math::
        loss=\sum_{i=1}^k(\sigma_{s,i}^2+\sigma_{t,i}^2)

    where :math:`\sigma_{s,i},\sigma_{t,i}` refer to the :math:`i-th` largest singular value of source features
    and target features respectively. We empirically set :math:`k=1`.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.
        - Outputs: scalar.
    """

    def __init__(self):
        super(BatchSpectralPenalizationLoss, self).__init__()

    def forward(self, f_s, f_t):
        # penalize only the largest singular value of each domain's feature matrix (k = 1)
        # NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd — confirm
        # against the pinned torch version before migrating
        _, s_s, _ = torch.svd(f_s)
        _, s_t, _ = torch.svd(f_t)
        loss = torch.pow(s_s[0], 2) + torch.pow(s_t[0], 2)
        return loss


class ImageClassifier(ClassifierBase):
    """Classifier with a Linear -> BatchNorm1d -> ReLU bottleneck projecting backbone
    features down to ``bottleneck_dim`` before the classification head."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten(),
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)


================================================
FILE: tllib/alignment/cdan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.utils.metric import binary_accuracy, accuracy
from tllib.modules.grl import WarmStartGradientReverseLayer
from tllib.modules.entropy import entropy

__all__ = ['ConditionalDomainAdversarialLoss', 'ImageClassifier']
class ConditionalDomainAdversarialLoss(nn.Module):
    r"""The Conditional Domain Adversarial Loss used in `Conditional Adversarial Domain Adaptation (NIPS 2018) `_

    Conditional Domain adversarial loss measures the domain discrepancy through training a domain discriminator in a
    conditional manner. Given domain discriminator :math:`D`, feature representation :math:`f` and
    classifier predictions :math:`g`, the definition of CDAN loss is

    .. math::
        loss(\mathcal{D}_s, \mathcal{D}_t) &= \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(T(f_i^s, g_i^s))] \\
        &+ \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(T(f_j^t, g_j^t))],\\

    where :math:`T` is a :class:`MultiLinearMap` or :class:`RandomizedMultiLinearMap` which convert two tensors to a single tensor.

    Args:
        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains of features.
            Its input shape is (N, F) and output shape is (N, 1)
        entropy_conditioning (bool, optional): If True, use entropy-aware weight to reweight each training example.
            Default: False
        randomized (bool, optional): If True, use `randomized multi linear map`. Else, use `multi linear map`.
            Default: False
        num_classes (int, optional): Number of classes. Default: -1
        features_dim (int, optional): Dimension of input features. Default: -1
        randomized_dim (int, optional): Dimension of features after randomized. Default: 1024
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in the output,
            ``'sum'``: the output will be summed. Default: ``'mean'``

    .. note::
        You need to provide `num_classes`, `features_dim` and `randomized_dim` **only when** `randomized`
        is set True.

    Inputs:
        - g_s (tensor): unnormalized classifier predictions on source domain, :math:`g^s`
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - g_s, g_t: :math:`(minibatch, C)` where C means the number of classes.
        - f_s, f_t: :math:`(minibatch, F)` where F means the dimension of input features.
        - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, )`.

    Examples::

        >>> from tllib.modules.domain_discriminator import DomainDiscriminator
        >>> from tllib.alignment.cdan import ConditionalDomainAdversarialLoss
        >>> import torch
        >>> num_classes = 2
        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> discriminator = DomainDiscriminator(in_feature=feature_dim * num_classes, hidden_size=1024)
        >>> loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean')
        >>> # features from source domain and target domain
        >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # logits output from source domain adn target domain
        >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes)
        >>> output = loss(g_s, f_s, g_t, f_t)
    """

    def __init__(self, domain_discriminator: nn.Module, entropy_conditioning: Optional[bool] = False,
                 randomized: Optional[bool] = False, num_classes: Optional[int] = -1,
                 features_dim: Optional[int] = -1, randomized_dim: Optional[int] = 1024,
                 reduction: Optional[str] = 'mean', sigmoid=True):
        super(ConditionalDomainAdversarialLoss, self).__init__()
        self.domain_discriminator = domain_discriminator
        # gradient reversal layer whose coefficient warms up from lo=0 to hi=1
        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000, auto_step=True)
        self.entropy_conditioning = entropy_conditioning
        self.sigmoid = sigmoid
        self.reduction = reduction

        if randomized:
            assert num_classes > 0 and features_dim > 0 and randomized_dim > 0
            self.map = RandomizedMultiLinearMap(features_dim, num_classes, randomized_dim)
        else:
            self.map = MultiLinearMap()

        self.bce = lambda input, target, weight: F.binary_cross_entropy(input, target, weight,
                                                                        reduction=reduction) if self.entropy_conditioning \
            else F.binary_cross_entropy(input, target, reduction=reduction)
        # updated on every forward pass; useful for logging
        self.domain_discriminator_accuracy = None

    def forward(self, g_s: torch.Tensor, f_s: torch.Tensor, g_t: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        f = torch.cat((f_s, f_t), dim=0)
        g = torch.cat((g_s, g_t), dim=0)
        # condition the discriminator input on detached classifier predictions
        g = F.softmax(g, dim=1).detach()
        h = self.grl(self.map(f, g))
        d = self.domain_discriminator(h)

        # entropy-aware weight: confident (low-entropy) predictions get larger weight,
        # normalized so the weights sum to batch_size
        weight = 1.0 + torch.exp(-entropy(g))
        batch_size = f.size(0)
        weight = weight / torch.sum(weight) * batch_size

        if self.sigmoid:
            # discriminator outputs a probability in (N, 1); source label 1, target label 0
            d_label = torch.cat((
                torch.ones((g_s.size(0), 1)).to(g_s.device),
                torch.zeros((g_t.size(0), 1)).to(g_t.device),
            ))
            self.domain_discriminator_accuracy = binary_accuracy(d, d_label)
            if self.entropy_conditioning:
                return F.binary_cross_entropy(d, d_label, weight.view_as(d), reduction=self.reduction)
            else:
                return F.binary_cross_entropy(d, d_label, reduction=self.reduction)
        else:
            # discriminator outputs 2-class logits; use cross entropy instead
            d_label = torch.cat((
                torch.ones((g_s.size(0), )).to(g_s.device),
                torch.zeros((g_t.size(0), )).to(g_t.device),
            )).long()
            self.domain_discriminator_accuracy = accuracy(d, d_label)
            if self.entropy_conditioning:
                raise NotImplementedError("entropy_conditioning")
            return F.cross_entropy(d, d_label, reduction=self.reduction)


class RandomizedMultiLinearMap(nn.Module):
    """Random multi linear map

    Given two inputs :math:`f` and :math:`g`, the definition is

    .. math::
        T_{\odot}(f,g) = \dfrac{1}{\sqrt{d}} (R_f f) \odot (R_g g),

    where :math:`\odot` is element-wise product, :math:`R_f` and :math:`R_g` are random matrices
    sampled only once and fixed in training.

    Args:
        features_dim (int): dimension of input :math:`f`
        num_classes (int): dimension of input :math:`g`
        output_dim (int, optional): dimension of output tensor.
Default: 1024

    Shape:
        - f: (minibatch, features_dim)
        - g: (minibatch, num_classes)
        - Outputs: (minibatch, output_dim)
    """

    def __init__(self, features_dim: int, num_classes: int, output_dim: Optional[int] = 1024):
        super(RandomizedMultiLinearMap, self).__init__()
        # fixed random projections, sampled once at construction (not trained,
        # not registered as parameters/buffers)
        self.Rf = torch.randn(features_dim, output_dim)
        self.Rg = torch.randn(num_classes, output_dim)
        self.output_dim = output_dim

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        f = torch.mm(f, self.Rf.to(f.device))
        g = torch.mm(g, self.Rg.to(g.device))
        # element-wise product scaled by 1/sqrt(d)
        output = torch.mul(f, g) / np.sqrt(float(self.output_dim))
        return output


class MultiLinearMap(nn.Module):
    """Multi linear map

    Shape:
        - f: (minibatch, F)
        - g: (minibatch, C)
        - Outputs: (minibatch, F * C)
    """

    def __init__(self):
        super(MultiLinearMap, self).__init__()

    def forward(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        batch_size = f.size(0)
        # per-example outer product g f^T, flattened to (N, F * C)
        output = torch.bmm(g.unsqueeze(2), f.unsqueeze(1))
        return output.view(batch_size, -1)


class ImageClassifier(ClassifierBase):
    # Classifier with a Linear -> BatchNorm -> ReLU bottleneck between backbone and head.
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten(),
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)


================================================
FILE: tllib/alignment/coral.py
================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch
import torch.nn as nn


class CorrelationAlignmentLoss(nn.Module):
    r"""The `Correlation Alignment Loss` in `Deep CORAL: Correlation Alignment for Deep Domain
    Adaptation (ECCV 2016) `_.

    Given source features :math:`f_S` and target features :math:`f_T`, the covariance matrices are given by

    .. math::
        C_S = \frac{1}{n_S-1}(f_S^Tf_S-\frac{1}{n_S}(\textbf{1}^Tf_S)^T(\textbf{1}^Tf_S))

    .. math::
        C_T = \frac{1}{n_T-1}(f_T^Tf_T-\frac{1}{n_T}(\textbf{1}^Tf_T)^T(\textbf{1}^Tf_T))

    where :math:`\textbf{1}` denotes a column vector with all elements equal to 1,
    :math:`n_S, n_T` denotes number of source and target samples, respectively. We use :math:`d` to denote feature
    dimension, use :math:`{\Vert\cdot\Vert}^2_F` to denote the squared matrix `Frobenius norm`.
    The correlation alignment loss is given by

    .. math::
        l_{CORAL} = \frac{1}{4d^2}\Vert C_S-C_T \Vert^2_F

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`

    Shape:
        - f_s, f_t: :math:`(N, d)` where d means the dimension of input features, :math:`N=n_S=n_T` is mini-batch size.
        - Outputs: scalar.
    """

    def __init__(self):
        super(CorrelationAlignmentLoss, self).__init__()

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor) -> torch.Tensor:
        mean_s = f_s.mean(0, keepdim=True)
        mean_t = f_t.mean(0, keepdim=True)
        cent_s = f_s - mean_s
        cent_t = f_t - mean_t
        # unbiased covariance estimates (divide by N - 1)
        cov_s = torch.mm(cent_s.t(), cent_s) / (len(f_s) - 1)
        cov_t = torch.mm(cent_t.t(), cent_t) / (len(f_t) - 1)

        # NOTE(review): unlike the docstring formula, the implementation also matches
        # first moments (mean_diff) and averages squared element-wise differences
        # rather than using the 1/(4d^2)-scaled Frobenius norm — confirm intended.
        mean_diff = (mean_s - mean_t).pow(2).mean()
        cov_diff = (cov_s - cov_t).pow(2).mean()
        return mean_diff + cov_diff


================================================
FILE: tllib/alignment/d_adapt/__init__.py
================================================


================================================
FILE: tllib/alignment/d_adapt/feedback.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import itertools
import numpy as np
import copy
import logging
from typing import List, Optional, Union

import torch

from detectron2.config import configurable
from detectron2.structures import BoxMode, Boxes, Instances
from detectron2.data.catalog import DatasetCatalog, MetadataCatalog
from detectron2.data.build import filter_images_with_only_crowd_annotations, filter_images_with_few_keypoints, \
print_instances_class_histogram
from detectron2.data.detection_utils import check_metadata_consistency
import detectron2.data.transforms as T
import detectron2.data.detection_utils as utils

from .proposal import Proposal


def load_feedbacks_into_dataset(dataset_dicts, proposals_list: List[Proposal]):
    """
    Load precomputed object feedbacks into the dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        proposals_list (list[Proposal]): list of Proposal.

    Returns:
        list[dict]: the same format as dataset_dicts, but added feedback field.
    """
    # index feedbacks by image_id so they can be joined with the records below
    feedbacks = {}
    for record in dataset_dicts:
        image_id = str(record["image_id"])
        feedbacks[image_id] = {
            'pred_boxes': [],
            'pred_classes': [],
        }
    for proposals in proposals_list:
        image_id = str(proposals.image_id)
        feedbacks[image_id]['pred_boxes'] += proposals.pred_boxes.tolist()
        feedbacks[image_id]['pred_classes'] += proposals.pred_classes.tolist()

    # Assuming default bbox_mode of precomputed feedbacks are 'XYXY_ABS'
    bbox_mode = BoxMode.XYXY_ABS

    dataset_dicts_with_feedbacks = []
    for record in dataset_dicts:
        # Get the index of the feedback
        image_id = str(record["image_id"])
        record["feedback_proposal_boxes"] = feedbacks[image_id]["pred_boxes"]
        record["feedback_gt_classes"] = feedbacks[image_id]["pred_classes"]
        record["feedback_gt_boxes"] = feedbacks[image_id]["pred_boxes"]
        record["feedback_bbox_mode"] = bbox_mode
        # keep only images that have at least one feedback with a valid (>= 0) class
        if sum(map(lambda x: x >= 0, feedbacks[image_id]["pred_classes"])) > 0:
            # remove images without feedbacks
            dataset_dicts_with_feedbacks.append(record)

    return dataset_dicts_with_feedbacks


def get_detection_dataset_dicts(names, filter_empty=True, min_keypoints=0, proposals_list=None):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposals_list (optional, list[Proposal]): list of Proposal.

    Returns:
        list[dict]: a list of dicts following the standard dataset dict format.
    """
    if isinstance(names, str):
        names = [names]
    assert len(names), names
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names]
    for dataset_name, dicts in zip(names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    if proposals_list is not None:
        # load precomputed feedbacks for each proposals
        dataset_dicts = load_feedbacks_into_dataset(dataset_dicts, proposals_list)

    has_instances = "annotations" in dataset_dicts[0]
    if filter_empty and has_instances:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)
    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        try:
            class_names = MetadataCatalog.get(names[0]).thing_classes
            check_metadata_consistency("thing_classes", names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:
            # class names are not available for this dataset
            pass

    assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names))
    return dataset_dicts


def transform_feedbacks(dataset_dict, image_shape, transforms, *, min_box_size=0):
    """
    Apply transformations to the feedbacks in dataset_dict, if any.

    Args:
        dataset_dict (dict): a dict read from the dataset, possibly
            contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode"
        image_shape (tuple): height, width
        transforms (TransformList):
        min_box_size (int): proposals with either side smaller than this
            threshold are removed

    The input dict is modified in-place, with abovementioned keys removed. A new
    key "proposals" will be added. Its value is an `Instances`
    object which contains the transformed proposals in its field
    "proposal_boxes" and "objectness_logits".
    """
    if "feedback_proposal_boxes" in dataset_dict:
        # Transform proposal boxes
        proposal_boxes = transforms.apply_box(
            BoxMode.convert(
                dataset_dict.pop("feedback_proposal_boxes"),
                dataset_dict.get("feedback_bbox_mode"),
                BoxMode.XYXY_ABS,
            )
        )
        proposal_boxes = Boxes(proposal_boxes)
        gt_boxes = transforms.apply_box(
            BoxMode.convert(
                dataset_dict.pop("feedback_gt_boxes"),
                dataset_dict.get("feedback_bbox_mode"),
                BoxMode.XYXY_ABS,
            )
        )
        gt_boxes = Boxes(gt_boxes)
        gt_classes = torch.as_tensor(
            dataset_dict.pop("feedback_gt_classes")
        )
        proposal_boxes.clip(image_shape)
        gt_boxes.clip(image_shape)
        # keep proposals that stay non-empty after clipping and carry a valid (>= 0) class
        keep = proposal_boxes.nonempty(threshold=min_box_size) & (gt_classes >= 0)
        # keep = boxes.nonempty(threshold=min_box_size)
        proposal_boxes = proposal_boxes[keep]
        gt_boxes = gt_boxes[keep]
        gt_classes = gt_classes[keep]

        feedbacks = Instances(image_shape)
        feedbacks.proposal_boxes = proposal_boxes
        feedbacks.gt_boxes = gt_boxes
        feedbacks.gt_classes = gt_classes
        dataset_dict["feedbacks"] = feedbacks


class DatasetMapper:
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and map it into a format used by the model.

    This is the default callable to be used to map your dataset dict into training data.
    You may need to follow it to implement your own one for customized logic,
    such as a different way to read or transform images.
    See :doc:`/tutorials/data_loading` for details.

    The callable currently does the following:

    1.
Read the image from "file_name"
    2. Applies cropping/geometric transforms to the image and annotations
    3. Prepare data and annotations to Tensor and :class:`Instances`
    """

    @configurable
    def __init__(
        self,
        is_train: bool,
        *,
        augmentations: List[Union[T.Augmentation, T.Transform]],
        image_format: str,
        use_instance_mask: bool = False,
        use_keypoint: bool = False,
        instance_mask_format: str = "polygon",
        keypoint_hflip_indices: Optional[np.ndarray] = None,
        precomputed_proposal_topk: Optional[int] = None,
        recompute_boxes: bool = False,
    ):
        """
        NOTE: this interface is experimental.

        Args:
            is_train: whether it's used in training or inference
            augmentations: a list of augmentations or deterministic transforms to apply
            image_format: an image format supported by :func:`detection_utils.read_image`.
            use_instance_mask: whether to process instance segmentation annotations, if available
            use_keypoint: whether to process keypoint annotations if available
            instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation
                masks into this format.
            keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices`
            precomputed_proposal_topk: if given, will load pre-computed
                proposals from dataset_dict and keep the top k proposals for each image.
            recompute_boxes: whether to overwrite bounding box annotations
                by computing tight bounding boxes from instance mask annotations.
        """
        if recompute_boxes:
            assert use_instance_mask, "recompute_boxes requires instance masks"
        # fmt: off
        self.is_train = is_train
        self.augmentations = T.AugmentationList(augmentations)
        self.image_format = image_format
        self.use_instance_mask = use_instance_mask
        self.instance_mask_format = instance_mask_format
        self.use_keypoint = use_keypoint
        self.keypoint_hflip_indices = keypoint_hflip_indices
        self.proposal_topk = precomputed_proposal_topk
        self.recompute_boxes = recompute_boxes
        # fmt: on

        logger = logging.getLogger(__name__)
        mode = "training" if is_train else "inference"
        logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}")

    @classmethod
    def from_config(cls, cfg, is_train: bool = True):
        # build the standard augmentation list from cfg; optionally prepend RandomCrop
        augs = utils.build_augmentation(cfg, is_train)
        if cfg.INPUT.CROP.ENABLED and is_train:
            augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE))
            recompute_boxes = cfg.MODEL.MASK_ON
        else:
            recompute_boxes = False

        ret = {
            "is_train": is_train,
            "augmentations": augs,
            "image_format": cfg.INPUT.FORMAT,
            "use_instance_mask": cfg.MODEL.MASK_ON,
            "instance_mask_format": cfg.INPUT.MASK_FORMAT,
            "use_keypoint": cfg.MODEL.KEYPOINT_ON,
            "recompute_boxes": recompute_boxes,
        }

        if cfg.MODEL.KEYPOINT_ON:
            ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN)

        if cfg.MODEL.LOAD_PROPOSALS:
            ret["precomputed_proposal_topk"] = (
                cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN
                if is_train
                else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST
            )
        return ret

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)

        # USER: Remove if you don't do semantic/panoptic segmentation.
        if "sem_seg_file_name" in dataset_dict:
            sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2)
        else:
            sem_seg_gt = None

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        transforms = self.augmentations(aug_input)
        image, sem_seg_gt = aug_input.image, aug_input.sem_seg

        image_shape = image.shape[:2]  # h, w
        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        if sem_seg_gt is not None:
            dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long"))

        # USER: Remove if you don't use pre-computed proposals.
        # Most users would not need this feature.
        if self.proposal_topk is not None:
            utils.transform_proposals(
                dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk
            )
        # apply the same geometric transforms to adaptor feedbacks, if present
        transform_feedbacks(
            dataset_dict, image_shape, transforms
        )

        if not self.is_train:
            # USER: Modify this if you want to keep them for some reason.
            dataset_dict.pop("annotations", None)
            dataset_dict.pop("sem_seg_file_name", None)
            return dataset_dict

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                utils.transform_instance_annotations(
                    obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices
                )
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, image_shape, mask_format=self.instance_mask_format
            )

            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict


================================================
FILE: tllib/alignment/d_adapt/modeling/__init__.py
================================================
from . import meta_arch
from . import roi_heads


================================================
FILE: tllib/alignment/d_adapt/modeling/matcher.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
from torch import Tensor, nn

from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple


class MaxOverlapMatcher(object):
    """
    This class assigns to each predicted "element" (e.g., a box) a ground-truth
    element. Each predicted element will have exactly zero or one matches; each
    ground-truth element may be matched to one predicted elements.
""" def __init__(self): pass def __call__(self, match_quality_matrix): """ Args: match_quality_matrix (Tensor[float]): an MxN tensor, containing the pairwise quality between M ground-truth elements and N predicted elements. All elements must be >= 0 (due to the us of `torch.nonzero` for selecting indices in :meth:`set_low_quality_matches_`). Returns: matches (Tensor[int64]): a vector of length N, where matches[i] is a matched ground-truth index in [0, M) match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates whether a prediction is a true or false positive or ignored """ assert match_quality_matrix.dim() == 2 # match_quality_matrix is M (gt) x N (predicted) # Max over gt elements (dim 0) to find best gt candidate for each prediction _, matched_idxs = match_quality_matrix.max(dim=0) anchor_labels = match_quality_matrix.new_full( (match_quality_matrix.size(1),), -1, dtype=torch.int8 ) # For each gt, find the prediction with which it has highest quality highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) # Find the highest quality match available, even if it is low, including ties. # Note that the matches qualities must be positive due to the use of # `torch.nonzero`. 
_, pred_inds_with_highest_quality = nonzero_tuple( match_quality_matrix == highest_quality_foreach_gt[:, None] ) anchor_labels[pred_inds_with_highest_quality] = 1 return matched_idxs, anchor_labels ================================================ FILE: tllib/alignment/d_adapt/modeling/meta_arch/__init__.py ================================================ from .rcnn import DecoupledGeneralizedRCNN from .retinanet import DecoupledRetinaNet ================================================ FILE: tllib/alignment/d_adapt/modeling/meta_arch/rcnn.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch from typing import Optional, Callable, Tuple, Any, List, Sequence, Dict import numpy as np from detectron2.utils.events import get_event_storage from detectron2.structures import Instances from detectron2.data.detection_utils import convert_image_to_rgb from detectron2.modeling.postprocessing import detector_postprocess from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY from tllib.vision.models.object_detection.meta_arch import TLGeneralizedRCNN @META_ARCH_REGISTRY.register() class DecoupledGeneralizedRCNN(TLGeneralizedRCNN): """ Generalized R-CNN for Decoupled Adaptation (D-adapt). Similar to that in in Supervised Learning, DecoupledGeneralizedRCNN has the following three components: 1. Per-image feature extraction (aka backbone) 2. Region proposal generation 3. Per-region feature extraction and prediction Different from that in Supervised Learning, DecoupledGeneralizedRCNN 1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training 2. 
generate foreground and background proposals during inference

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        proposal_generator: a module that generates proposals using backbone features
        roi_heads: a ROI head that performs per-region computation
        pixel_mean, pixel_std: list or tuple with #channels element, representing
            the per-channel mean and std to be used to normalize the input image
        input_format: describe the meaning of channels of input. Needed by visualization
        vis_period: the period to run visualization. Set to 0 to disable.
        finetune (bool): whether finetune the detector or train from scratch. Default: True

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item in the list contains the inputs for one image.
          For now, each item in the list is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * feedbacks (optional): :class:`Instances`, feedbacks from adaptors.
          * "height", "width" (int): the output resolution of the model, used in inference.
            See :meth:`postprocess` for details.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - outputs (during inference): A list of dict where each dict is the output for one input image.
          The dict contains a key "instances" whose value is a :class:`Instances`.
          The :class:`Instances` object has the following keys:
          "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
        - losses (during training): A dict of different losses
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        # feedbacks from the adaptors act as supervision on unlabeled data
        if "feedbacks" in batched_inputs[0]:
            feedbacks = [x["feedbacks"].to(self.device) for x in batched_inputs]
        else:
            feedbacks = None

        proposals, proposal_losses = self.proposal_generator(images, features, gt_instances, labeled=labeled)
        _, _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, feedbacks, labeled=labeled)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)

        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals, feedbacks)

        return losses

    def visualize_training(self, batched_inputs, proposals, feedbacks=None):
        """
        A function used to visualize images and proposals. It shows ground truth
        bounding boxes on the original image and up to 20 top-scoring predicted
        object proposals on the original image. Users can implement different
        visualization functions for different models.

        Args:
            batched_inputs (list): a list that contains input to the model.
            proposals (list): a list that contains predicted proposals. Both
                batched_inputs and proposals should have the same length.
            feedbacks (list): a list that contains feedbacks from adaptors. Both
                batched_inputs and feedbacks should have the same length.
        """
        from detectron2.utils.visualizer import Visualizer

        storage = get_event_storage()
        max_vis_prop = 20

        for input, prop in zip(batched_inputs, proposals):
            img = input["image"]
            img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format)
            v_gt = Visualizer(img, None)
            v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes)
            anno_img = v_gt.get_image()
            box_size = min(len(prop.proposal_boxes), max_vis_prop)
            v_pred = Visualizer(img, None)
            v_pred = v_pred.overlay_instances(
                boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy()
            )
            prop_img = v_pred.get_image()
            num_classes = self.roi_heads.box_predictor.num_classes
            if feedbacks is not None:
                # split feedbacks into foreground (gt_classes != num_classes) and
                # background (gt_classes == num_classes) proposals for display
                v_feedback_gt = Visualizer(img, None)
                instance = feedbacks[0].to(torch.device("cpu"))
                v_feedback_gt = v_feedback_gt.overlay_instances(
                    boxes=instance.proposal_boxes[instance.gt_classes != num_classes])
                feedback_gt_img = v_feedback_gt.get_image()
                v_feedback_gf = Visualizer(img, None)
                v_feedback_gf = v_feedback_gf.overlay_instances(
                    boxes=instance.proposal_boxes[instance.gt_classes == num_classes])
                feedback_gf_img = v_feedback_gf.get_image()
                vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img))
                vis_img = vis_img.transpose(2, 0, 1)
                vis_name = f"Top: GT; Middle: Pred; Bottom: Feedback GT, Feedback GF"
            else:
                vis_img = np.concatenate((anno_img, prop_img), axis=1)
                vis_img = vis_img.transpose(2, 0, 1)
                vis_name = "Left: GT bounding boxes; Right: Predicted proposals"
            storage.put_image(vis_name, vis_img)
            break  # only visualize one image in a batch

    def inference(
        self,
        batched_inputs: Tuple[Dict[str, torch.Tensor]],
        detected_instances: Optional[List[Instances]] = None,
        do_postprocess: bool = True,
    ):
        """
        Run inference on the given inputs.

        Args:
            batched_inputs (list[dict]): same as in :meth:`forward`
            detected_instances (None or list[Instances]): if not None, it
                contains an `Instances` object per image. The `Instances`
                object contains "pred_boxes" and "pred_classes" which are
                known boxes in the image. The inference will then skip the detection
                of bounding boxes, and only predict other per-ROI outputs.
            do_postprocess (bool): whether to apply post-processing on the outputs.

        Returns:
            When do_postprocess=True, same as in :meth:`forward`.
            Otherwise, a list[Instances] containing raw network outputs.
        """
        assert not self.training

        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        proposals, _ = self.proposal_generator(images, features, None)
        # roi_heads also returns background proposals during inference
        results, background_results, _ = self.roi_heads(images, features, proposals, None)

        processed_results = []
        for results_per_image, background_results_per_image, input_per_image, image_size in zip(
                results, background_results, batched_inputs, images.image_sizes
        ):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])
            r = detector_postprocess(results_per_image, height, width)
            background_r = detector_postprocess(background_results_per_image, height, width)
            processed_results.append({"instances": r, 'background': background_r})
        return processed_results


================================================
FILE: tllib/alignment/d_adapt/modeling/meta_arch/retinanet.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Callable, Tuple, Any, List, Sequence, Dict
import random
import numpy as np

import torch
from torch import Tensor

from detectron2.structures import BoxMode, Boxes, Instances, pairwise_iou, ImageList
from detectron2.layers import ShapeSpec, batched_nms, cat, get_norm, nonzero_tuple
from detectron2.modeling import detector_postprocess
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.data.detection_utils import convert_image_to_rgb
from detectron2.utils.events import get_event_storage

from tllib.vision.models.object_detection.meta_arch import TLRetinaNet
from ..matcher import MaxOverlapMatcher
@META_ARCH_REGISTRY.register() class DecoupledRetinaNet(TLRetinaNet): """ RetinaNet for Decoupled Adaptation (D-adapt). Different from that in Supervised Learning, DecoupledRetinaNet 1. accepts unlabeled images and uses the feedbacks from adaptors as supervision during training 2. generate foreground and background proposals during inference Args: backbone: a backbone module, must follow detectron2's backbone interface head (nn.Module): a module that predicts logits and regression deltas for each level from a list of per-level features head_in_features (Tuple[str]): Names of the input feature maps to be used in head anchor_generator (nn.Module): a module that creates anchors from a list of features. Usually an instance of :class:`AnchorGenerator` box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to instance boxes anchor_matcher (Matcher): label the anchors by matching them with ground truth. num_classes (int): number of classes. Used to label background proposals. # Loss parameters: focal_loss_alpha (float): focal_loss_alpha focal_loss_gamma (float): focal_loss_gamma smooth_l1_beta (float): smooth_l1_beta box_reg_loss_type (str): Options are "smooth_l1", "giou" # Inference parameters: test_score_thresh (float): Inference cls score threshold, only anchors with score > INFERENCE_TH are considered for inference (to improve speed) test_topk_candidates (int): Select topk candidates before NMS test_nms_thresh (float): Overlap threshold used for non-maximum suppression (suppress boxes with IoU >= this threshold) max_detections_per_image (int): Maximum number of detections to return per image during inference (100 is based on the limit established for the COCO dataset). # Input parameters pixel_mean (Tuple[float]): Values to be used for image normalization (BGR order). To train on images of different number of channels, set different mean & std. 
Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675] pixel_std (Tuple[float]): When using pre-trained models in Detectron1 or any MSRA models, std has been absorbed into its conv1 weights, so the std needs to be set 1. Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) vis_period (int): The period (in terms of steps) for minibatch visualization at train time. Set to 0 to disable. input_format (str): Whether the model needs RGB, YUV, HSV etc. finetune (bool): whether finetune the detector or train from scratch. Default: True Inputs: - batched_inputs: a list, batched outputs of :class:`DatasetMapper`. Each item in the list contains the inputs for one image. For now, each item in the list is a dict that contains: * image: Tensor, image in (C, H, W) format. * instances (optional): groundtruth :class:`Instances` * "height", "width" (int): the output resolution of the model, used in inference. See :meth:`postprocess` for details. - labeled (bool, optional): whether has ground-truth label Outputs: - outputs: A list of dict where each dict is the output for one input image. The dict contains a key "instances" whose value is a :class:`Instances` and a key "features" whose value is the features of middle layers. 
The :class:`Instances` object has the following keys: "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" - losses: A dict of different losses """ def __init__(self, *args, max_samples_per_level=25, **kwargs): super(DecoupledRetinaNet, self).__init__(*args, **kwargs) self.max_samples_per_level = max_samples_per_level self.max_matcher = MaxOverlapMatcher() def forward_training(self, images, features, predictions, gt_instances=None, feedbacks=None, labeled=True): # Transpose the Hi*Wi*A dimension to the middle: pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( predictions, [self.num_classes, 4] ) anchors = self.anchor_generator(features) if labeled: gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes) else: proposal_labels, proposal_boxes = self.label_pseudo_anchors(anchors, feedbacks) losses = self.losses(anchors, pred_logits, proposal_labels, pred_anchor_deltas, proposal_boxes) losses.pop('loss_box_reg') return losses def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True): images = self.preprocess_image(batched_inputs) features = self.backbone(images.tensor) features = [features[f] for f in self.head_in_features] predictions = self.head(features) if self.training: if "instances" in batched_inputs[0]: gt_instances = [x["instances"].to(self.device) for x in batched_inputs] else: gt_instances = None if "feedbacks" in batched_inputs[0]: feedbacks = [x["feedbacks"].to(self.device) for x in batched_inputs] else: feedbacks = None losses = self.forward_training(images, features, predictions, gt_instances, feedbacks, labeled) if self.vis_period > 0: storage = get_event_storage() if storage.iter % self.vis_period == 0: results = self.forward_inference(images, features, predictions) self.visualize_training(batched_inputs, results, feedbacks) return losses else: # sample_background must be called before inference # since 
    @torch.no_grad()
    def label_pseudo_anchors(self, anchors, instances):
        """
        Label anchors with pseudo ground truth from adaptor feedbacks, using
        ``self.max_matcher`` (a :class:`MaxOverlapMatcher`) instead of the IoU
        matcher used for real ground truth.

        Args:
            anchors (list[Boxes]): A list of #feature level Boxes. The Boxes contains
                anchors of this image on the specific feature level.
            instances (list[Instances]): a list of N `Instances`s. The i-th `Instances`
                contains the ground-truth per-instance annotations for the i-th input image.

        Returns:
            list[Tensor]: List of #img tensors. i-th element is a vector of labels whose
                length is the total number of anchors across all feature maps
                (sum(Hi * Wi * A)). Label values are in {-1, 0, ..., K}, with -1 meaning
                ignore and K meaning background.
            list[Tensor]: i-th element is a Rx4 tensor, where R is the total number of
                anchors across feature maps. The values are the matched gt boxes for each
                anchor. Values are undefined for anchors not labeled as foreground.
        """
        # Flatten all feature levels into one Rx4 box list.
        anchors = Boxes.cat(anchors)  # Rx4
        gt_labels = []
        matched_gt_boxes = []
        for gt_per_image in instances:
            # IoU between every feedback box and every anchor.
            match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors)
            matched_idxs, anchor_labels = self.max_matcher(match_quality_matrix)
            del match_quality_matrix  # free the (G x R) matrix early

            if len(gt_per_image) > 0:
                matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs]
                gt_labels_i = gt_per_image.gt_classes[matched_idxs]
                # Anchors with label -1 are ignored.
                gt_labels_i[anchor_labels == -1] = -1
            else:
                # No feedback boxes: every anchor becomes background (label K);
                # the box targets are zero placeholders and never used.
                matched_gt_boxes_i = torch.zeros_like(anchors.tensor)
                gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes

            gt_labels.append(gt_labels_i)
            matched_gt_boxes.append(matched_gt_boxes_i)
        return gt_labels, matched_gt_boxes
gt_labels_i[anchor_labels == -1] = -1 else: matched_gt_boxes_i = torch.zeros_like(anchors.tensor) gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes gt_labels.append(gt_labels_i) matched_gt_boxes.append(matched_gt_boxes_i) return gt_labels, matched_gt_boxes def sample_background( self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]] ): pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( predictions, [self.num_classes, 4] ) anchors = self.anchor_generator(features) results: List[Instances] = [] for img_idx, image_size in enumerate(images.image_sizes): scores_per_image = [x[img_idx].sigmoid() for x in pred_logits] deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] results_per_image = self.sample_background_single_image( anchors, scores_per_image, deltas_per_image, image_size ) results.append(results_per_image) return results def sample_background_single_image( self, anchors: List[Boxes], box_cls: List[Tensor], box_delta: List[Tensor], image_size: Tuple[int, int], ): boxes_all = [] scores_all = [] # Iterate over every feature level for box_cls_i, box_reg_i, anchors_i in zip(box_cls, box_delta, anchors): # (HxWxAxK,) predicted_prob = box_cls_i.max(dim=1).values # 1. Keep boxes with confidence score lower than threshold keep_idxs = predicted_prob < self.test_score_thresh anchor_idxs = nonzero_tuple(keep_idxs)[0] # 2. Random sample boxes anchor_idxs = anchor_idxs[ random.sample(range(len(anchor_idxs)), k=min(len(anchor_idxs), self.max_samples_per_level))] predicted_prob = predicted_prob[anchor_idxs] anchors_i = anchors_i[anchor_idxs] boxes_all.append(anchors_i.tensor) scores_all.append(predicted_prob) boxes_all, scores_all = [ cat(x) for x in [boxes_all, scores_all] ] result = Instances(image_size) result.pred_boxes = Boxes(boxes_all) result.scores = 1. 
- scores_all # the confidence score to be background result.pred_classes = torch.tensor([self.num_classes for _ in range(len(scores_all))]) return result def visualize_training(self, batched_inputs, results, feedbacks=None): """ A function used to visualize ground truth images and final network predictions. It shows ground truth bounding boxes on the original image and up to 20 predicted object bounding boxes on the original image. Args: batched_inputs (list): a list that contains input to the model. results (List[Instances]): a list of #images elements returned by forward_inference(). """ from detectron2.utils.visualizer import Visualizer assert len(batched_inputs) == len( results ), "Cannot visualize inputs and results of different sizes" storage = get_event_storage() max_boxes = 20 image_index = 0 # only visualize a single image img = batched_inputs[image_index]["image"] img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format) v_gt = Visualizer(img, None) v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes) anno_img = v_gt.get_image() processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1]) predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy() v_pred = Visualizer(img, None) v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes]) prop_img = v_pred.get_image() num_classes = self.num_classes if feedbacks is not None: v_feedback_gt = Visualizer(img, None) instance = feedbacks[0].to(torch.device("cpu")) v_feedback_gt = v_feedback_gt.overlay_instances( boxes=instance.proposal_boxes[instance.gt_classes != num_classes]) feedback_gt_img = v_feedback_gt.get_image() v_feedback_gf = Visualizer(img, None) v_feedback_gf = v_feedback_gf.overlay_instances( boxes=instance.proposal_boxes[instance.gt_classes == num_classes]) feedback_gf_img = v_feedback_gf.get_image() vis_img = np.vstack((anno_img, prop_img, feedback_gt_img, feedback_gf_img)) vis_img = 
def label_smoothing_cross_entropy(input, target, *, reduction="mean", **kwargs):
    """
    Label-smoothing cross entropy that is safe on empty batches.

    Behaves like :class:`tllib.modules.loss.LabelSmoothSoftmaxCEV1`, except
    that an empty ``target`` under ``reduction="mean"`` yields 0 instead of NaN.
    """
    empty_batch = target.numel() == 0
    if empty_batch and reduction == "mean":
        # Multiply through `input` by zero so the result stays connected to
        # the autograd graph.
        return input.sum() * 0.0
    criterion = LabelSmoothSoftmaxCEV1(reduction=reduction, **kwargs)
    return criterion(input, target)
Returns: Dict[str, Tensor]: dict of losses """ scores, proposal_deltas = predictions # parse classification outputs gt_classes = ( cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) ) _log_classification_stats(scores, gt_classes) # parse box regression outputs if len(proposals): proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4 assert not proposal_boxes.requires_grad, "Proposals should not require gradients!" # If "gt_boxes" does not exist, the proposals must be all negative and # should not be included in regression loss computation. # Here we just use proposal_boxes as an arbitrary placeholder because its # value won't be used in self.box_reg_loss(). gt_boxes = cat( [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals], dim=0, ) else: proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) losses = { "loss_cls": label_smoothing_cross_entropy(scores, gt_classes, reduction="mean"), "loss_box_reg": self.box_reg_loss( proposal_boxes, gt_boxes, proposal_deltas, gt_classes ), } return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} ================================================ FILE: tllib/alignment/d_adapt/modeling/roi_heads/roi_heads.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch import numpy as np import random from typing import List, Tuple, Dict from detectron2.structures import Boxes, Instances from detectron2.utils.events import get_event_storage from detectron2.layers import ShapeSpec, batched_nms from detectron2.modeling.roi_heads import ( ROI_HEADS_REGISTRY, Res5ROIHeads, StandardROIHeads ) from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference from detectron2.modeling.sampling import subsample_labels from detectron2.layers import nonzero_tuple from .fast_rcnn import DecoupledFastRCNNOutputLayers 
@ROI_HEADS_REGISTRY.register() class DecoupledRes5ROIHeads(Res5ROIHeads): """ The ROIHeads in a typical "C4" R-CNN model, where the box and mask head share the cropping and the per-region feature computation by a Res5 block. It typically contains logic to 1. when training on labeled source domain, match proposals with ground truth and sample them 2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them 3. crop the regions and extract per-region features using proposals 4. make per-region predictions with different heads """ def __init__(self, *args, **kwargs): super(DecoupledRes5ROIHeads, self).__init__(*args, **kwargs) @classmethod def from_config(cls, cfg, input_shape): # fmt: off ret = super().from_config(cfg, input_shape) ret["res5"], out_channels = cls._build_res5_block(cfg) box_predictor = DecoupledFastRCNNOutputLayers(cfg, ShapeSpec(channels=out_channels, height=1, width=1)) ret["box_predictor"] = box_predictor return ret def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True): """ Prepare some proposals to be used to train the ROI heads. When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns training labels to the proposals. When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns training labels to the proposals. It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth boxes, with a fraction of positives that is no larger than ``self.positive_fraction``. Args: images (ImageList): features (dict[str,Tensor]): input data as a mapping from feature map name to tensor. Axis 0 represents the number of images `N` in the input data; axes 1-3 are channels, height, and width, which may vary between feature maps (e.g., if a feature pyramid is used). proposals (list[Instances]): length `N` list of `Instances`. 
    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):
        """
        Label and sample proposals, then run the box head.

        During training on the labeled source domain, proposals are matched and
        sampled against ``targets``; on the unlabeled target domain, the
        adaptor ``feedbacks`` are sampled directly instead (the RPN proposals
        are discarded in that branch). During inference, both foreground
        detections and background proposals are produced.

        Args:
            images (ImageList): unused; deleted immediately.
            features (dict[str, Tensor]): feature-map name -> NCHW tensor.
            proposals (list[Instances]): length `N` list with fields
                "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): ground-truth annotations,
                required when training.
            feedbacks (list[Instances], optional): adaptor feedbacks, used when
                training with ``labeled=False``. Same fields as `targets`.
            labeled (bool, optional): whether this batch has ground-truth labels.

        Returns:
            tuple[list[Instances], list[Instances], dict]: foreground
            instances, background instances, and a dict of losses. During
            training the two lists are empty; during inference the dict is empty.
        """
        del images
        if self.training:
            assert targets
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                # Feedbacks replace the RPN proposals entirely on target domain.
                proposals = self.label_and_sample_feedbacks(feedbacks)
            del targets
        proposal_boxes = [x.proposal_boxes for x in proposals]
        box_features = self._shared_roi_transform(
            [features[f] for f in self.in_features], proposal_boxes
        )
        # Res5 features are global-average-pooled before the predictor.
        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))

        if self.training:
            del features
            losses = self.box_predictor.losses(predictions, proposals)
            if not labeled:
                # Pseudo boxes are unreliable localization targets.
                losses.pop("loss_box_reg")
            return [], [], losses
        else:
            # NOTE(review): this first inference call is overwritten two lines
            # below by fast_rcnn_inference and appears redundant — confirm
            # before removing.
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            boxes = self.box_predictor.predict_boxes(predictions, proposals)
            scores = self.box_predictor.predict_probs(predictions, proposals)
            image_shapes = [x.image_size for x in proposals]
            pred_instances, _ = fast_rcnn_inference(
                boxes,
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            # Background proposals are sampled from the raw proposal boxes.
            background_instances, _ = fast_rcnn_sample_background(
                [box.tensor for box in proposal_boxes],
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            background_instances = self.forward_with_given_boxes(features, background_instances)
            return pred_instances, background_instances, {}
The i-th `Instances` contains the feedback of per-instance annotations for the i-th input image. Specify `feedbacks` during training only. It have the same fields as `targets`. Returns: list[Instances]: length `N` list of `Instances`s containing the proposals sampled for training. Each `Instances` has the following fields: - proposal_boxes: the proposal boxes - gt_boxes: the ground-truth box that the proposal is assigned to (this is only meaningful if the proposal has a label > 0; if label = 0 then the ground-truth box is random) Other fields such as "gt_classes", "gt_masks", that's included in `targets`. """ proposals_with_gt = [] num_fg_samples = [] num_bg_samples = [] for feedbacks_per_image in feedbacks: gt_classes = feedbacks_per_image.gt_classes positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0] # ensure each batch consists the same number bg and fg boxes batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1)) sampled_fg_idxs, sampled_bg_idxs = subsample_labels( gt_classes, batch_size, self.positive_fraction, self.num_classes ) sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0) gt_classes = gt_classes[sampled_idxs] # Set target attributes of the sampled proposals: proposals_per_image = feedbacks_per_image[sampled_idxs] proposals_per_image.gt_classes = gt_classes num_bg_samples.append((gt_classes == self.num_classes).sum().item()) num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1]) proposals_with_gt.append(proposals_per_image) # Log the number of fg/bg samples that are selected for training ROI heads storage = get_event_storage() storage.put_scalar("roi_head_pseudo/num_fg_samples", np.mean(num_fg_samples)) storage.put_scalar("roi_head_pseudo/num_bg_samples", np.mean(num_bg_samples)) return proposals_with_gt @ROI_HEADS_REGISTRY.register() class DecoupledStandardROIHeads(StandardROIHeads): """ The Standard ROIHeads used by most models, such as FPN and C5. 
It's "standard" in a sense that there is no ROI transform sharing or feature sharing between tasks. Each head independently processes the input features by each head's own pooler and head. It typically contains logic to 1. when training on labeled source domain, match proposals with ground truth and sample them 2. when training on unlabeled target domain, match proposals with feedbacks from adaptors and sample them 3. crop the regions and extract per-region features using proposals 4. make per-region predictions with different heads """ def __init__(self, *args, **kwargs): super(DecoupledStandardROIHeads, self).__init__(*args, **kwargs) @classmethod def from_config(cls, cfg, input_shape): # fmt: off ret = super().from_config(cfg, input_shape) box_predictor = DecoupledFastRCNNOutputLayers(cfg, ret['box_head'].output_shape) ret["box_predictor"] = box_predictor return ret def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True): """ Prepare some proposals to be used to train the ROI heads. When training on labeled source domain, it performs box matching between `proposals` and `targets`, and assigns training labels to the proposals. When training on unlabeled target domain, it performs box matching between `proposals` and `feedbacks`, and assigns training labels to the proposals. It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth boxes, with a fraction of positives that is no larger than ``self.positive_fraction``. Args: images (ImageList): features (dict[str,Tensor]): input data as a mapping from feature map name to tensor. Axis 0 represents the number of images `N` in the input data; axes 1-3 are channels, height, and width, which may vary between feature maps (e.g., if a feature pyramid is used). proposals (list[Instances]): length `N` list of `Instances`. The i-th `Instances` contains object proposals for the i-th input image, with fields "proposal_boxes" and "objectness_logits". 
    def forward(self, images, features, proposals, targets=None, feedbacks=None, labeled=True):
        """
        Label and sample proposals, then run box / mask / keypoint heads.

        On the labeled source domain, proposals are matched and sampled against
        ``targets``; on the unlabeled target domain, the adaptor ``feedbacks``
        are sampled instead (the RPN proposals are discarded in that branch).
        During inference, both foreground detections and sampled background
        proposals are produced.

        Args:
            images (ImageList): unused; deleted immediately.
            features (dict[str, Tensor]): feature-map name -> NCHW tensor.
            proposals (list[Instances]): length `N` list with fields
                "proposal_boxes" and "objectness_logits".
            targets (list[Instances], optional): ground-truth annotations,
                required when training.
            feedbacks (list[Instances], optional): adaptor feedbacks, used when
                training with ``labeled=False``. Same fields as `targets`.
            labeled (bool, optional): whether this batch has ground-truth labels.

        Returns:
            tuple[list[Instances], list[Instances], dict]: foreground
            instances, background instances, and a dict of losses. During
            training the lists are empty; during inference the dict is empty.
        """
        del images
        if self.training:
            assert targets
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                # Feedbacks replace the RPN proposals entirely on target domain.
                proposals = self.label_and_sample_feedbacks(feedbacks)
            del targets

        if self.training:
            losses = self._forward_box(features, proposals)
            # Usually the original proposals used by the box head are used by the mask, keypoint
            # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes
            # predicted by the box head.
            losses.update(self._forward_mask(features, proposals))
            losses.update(self._forward_keypoint(features, proposals))
            if not labeled:
                # Pseudo boxes are unreliable localization targets.
                losses.pop('loss_box_reg')
            return [], [], losses
        else:
            pred_instances, predictions = self._forward_box(features, proposals)
            scores = self.box_predictor.predict_probs(predictions, proposals)
            image_shapes = [x.image_size for x in proposals]
            proposal_boxes = [x.proposal_boxes for x in proposals]
            # Background proposals are sampled from the raw proposal boxes.
            background_instances, _ = fast_rcnn_sample_background(
                [box.tensor for box in proposal_boxes],
                scores,
                image_shapes,
                self.box_predictor.test_score_thresh,
                self.box_predictor.test_nms_thresh,
                self.box_predictor.test_topk_per_image,
            )
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            background_instances = self.forward_with_given_boxes(features, background_instances)
            return pred_instances, background_instances, {}
    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """
        Forward logic of the box prediction branch.

        If `self.train_on_pred_boxes is True`, the function puts predicted boxes
        in the `proposal_boxes` field of the `proposals` argument (in-place).

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with their
                matching ground truth. Each has fields "proposal_boxes", and
                "objectness_logits", "gt_classes", "gt_boxes".

        Returns:
            In training, a dict of losses.
            In inference, a tuple of (list[Instances] predicted instances,
            the raw head predictions) — note the extra second element compared
            to the stock detectron2 implementation; the caller uses it to
            compute background-proposal scores.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)
        del box_features  # free pooled features early

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            return losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances, predictions
if self.train_on_pred_boxes: with torch.no_grad(): pred_boxes = self.box_predictor.predict_boxes_for_gt_classes( predictions, proposals ) for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes): proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image) return losses else: pred_instances, _ = self.box_predictor.inference(predictions, proposals) return pred_instances, predictions @torch.no_grad() def label_and_sample_feedbacks( self, feedbacks, batch_size_per_image=256 ) -> List[Instances]: """ Prepare some proposals to be used to train the ROI heads. It performs box matching between `proposals` and `targets`, and assigns training labels to the proposals. It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth boxes, with a fraction of positives that is no larger than ``self.positive_fraction``. Args: See :meth:`ROIHeads.forward` Returns: list[Instances]: length `N` list of `Instances`s containing the proposals sampled for training. Each `Instances` has the following fields: - proposal_boxes: the proposal boxes - gt_boxes: the ground-truth box that the proposal is assigned to (this is only meaningful if the proposal has a label > 0; if label = 0 then the ground-truth box is random) Other fields such as "gt_classes", "gt_masks", that's included in `targets`. 
""" proposals_with_gt = [] num_fg_samples = [] num_bg_samples = [] for feedbacks_per_image in feedbacks: gt_classes = feedbacks_per_image.gt_classes positive = nonzero_tuple((gt_classes != -1) & (gt_classes != self.num_classes))[0] # ensure each batch consists the same number bg and fg boxes batch_size = min(batch_size_per_image, max(2 * positive.numel(), 1)) sampled_fg_idxs, sampled_bg_idxs = subsample_labels( gt_classes, batch_size, self.positive_fraction, self.num_classes ) sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0) gt_classes = gt_classes[sampled_idxs] # Set target attributes of the sampled proposals: proposals_per_image = feedbacks_per_image[sampled_idxs] proposals_per_image.gt_classes = gt_classes num_bg_samples.append((gt_classes == self.num_classes).sum().item()) num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1]) proposals_with_gt.append(proposals_per_image) # Log the number of fg/bg samples that are selected for training ROI heads storage = get_event_storage() storage.put_scalar("roi_head_pseudo/num_fg_samples", np.mean(num_fg_samples)) storage.put_scalar("roi_head_pseudo/num_bg_samples", np.mean(num_bg_samples)) return proposals_with_gt def fast_rcnn_sample_background( boxes: List[torch.Tensor], scores: List[torch.Tensor], image_shapes: List[Tuple[int, int]], score_thresh: float, nms_thresh: float, topk_per_image: int, ): """ Call `fast_rcnn_sample_background_single_image` for all images. Args: boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic boxes for each image. Element i has shape (Ri, K * 4) if doing class-specific regression, or (Ri, 4) if doing class-agnostic regression, where Ri is the number of predicted objects for image i. This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`. scores (list[Tensor]): A list of Tensors of predicted class scores for each image. Element i has shape (Ri, K + 1), where Ri is the number of predicted objects for image i. 
def fast_rcnn_sample_background_single_image(
    boxes,
    scores,
    image_shape: Tuple[int, int],
    score_thresh: float,
    nms_thresh: float,
    topk_per_image: int,
):
    """
    Sample background proposals for a single image.

    Keeps only the background column of ``scores``, thresholds it, applies NMS
    and then randomly subsamples up to ``topk_per_image`` survivors.

    Args:
        Same as `fast_rcnn_sample_background`, but with boxes, scores, and
        image shapes per image.
        NOTE(review): the parent docstring says image shapes are
        (width, height); confirm the order against the callers.

    Returns:
        Same as `fast_rcnn_sample_background`, but for only one image.
    """
    # Drop rows with non-finite boxes or scores.
    valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1)
    if not valid_mask.all():
        boxes = boxes[valid_mask]
        scores = scores[valid_mask]

    # K + 1 columns (captured before slicing so it can rebuild the bg label below).
    num_classes = scores.shape[1]
    # Only keep background proposals (the last score column).
    scores = scores[:, -1:]
    # Convert to Boxes to use the `clip` function ...
    boxes = Boxes(boxes.reshape(-1, 4))
    boxes.clip(image_shape)
    boxes = boxes.tensor.view(-1, 1, 4)  # R x C x 4 (C == 1 here: background only)

    # 1. Filter results based on detection scores. It can make NMS more efficient
    #    by filtering out low-confidence detections.
    filter_mask = scores > score_thresh  # R
    # R' x 2. First column contains indices of the R predictions;
    # Second column contains indices of classes (always 0, since only the
    # background column remains).
    filter_inds = filter_mask.nonzero()
    boxes = boxes[filter_mask]
    scores = scores[filter_mask]
    # 2. Apply NMS only for background class
    keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh)
    if 0 <= topk_per_image < len(keep):
        # Random subsample (sorted to preserve the NMS score ordering).
        idx = list(range(len(keep)))
        idx = random.sample(idx, k=topk_per_image)
        idx = sorted(idx)
        keep = keep[idx]
    boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep]

    result = Instances(image_shape)
    result.pred_boxes = Boxes(boxes)
    result.scores = scores
    # filter_inds[:, 1] is always 0, so every kept proposal gets the label
    # num_classes - 1 == K, i.e. the background class of the (K+1)-way scores.
    result.pred_classes = filter_inds[:, 1] + num_classes - 1
    return result, filter_inds[:, 0]
    Prepare data and annotations to Tensor and :class:`Instances`
    """

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.

        Returns:
            dict: a format that builtin models in detectron2 accept
        """
        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below
        # USER: Write your own image loading if it's not from a file
        image = utils.read_image(dataset_dict["file_name"], format=self.image_format)
        utils.check_image_size(dataset_dict, image)
        origin_image_shape = image.shape[:2]  # h, w
        # NOTE(review): unlike the parent DatasetMapper, no augmentations are applied
        # here (AugInput is built but no transform list is run) — proposals must stay
        # aligned with the original image coordinates.
        aug_input = T.AugInput(image)
        image = aug_input.image

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))

        if "annotations" in dataset_dict:
            # USER: Modify this if you want to keep them for some reason.
            for anno in dataset_dict["annotations"]:
                if not self.use_instance_mask:
                    anno.pop("segmentation", None)
                if not self.use_keypoint:
                    anno.pop("keypoints", None)

            # USER: Implement additional transformations if you have other types of data
            annos = [
                obj
                for obj in dataset_dict.pop("annotations")
                if obj.get("iscrowd", 0) == 0
            ]
            instances = utils.annotations_to_instances(
                annos, origin_image_shape, mask_format=self.instance_mask_format
            )

            # After transforms such as cropping are applied, the bounding box may no longer
            # tightly bound the object. As an example, imagine a triangle object
            # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight
            # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to
            # the intersection of original bounding box and the cropping box.
            if self.recompute_boxes:
                instances.gt_boxes = instances.gt_masks.get_bounding_boxes()
            dataset_dict["instances"] = utils.filter_empty_instances(instances)
        return dataset_dict


class ProposalGenerator(DatasetEvaluator):
    """
    The function :func:`inference_on_dataset` runs the model over
    all samples in the dataset, and have a ProposalGenerator to generate proposals
    for each inputs/outputs.

    This class will accumulate information of the inputs/outputs (by :meth:`process`),
    and generate proposals results in the end (by :meth:`evaluate`).
    """

    def __init__(self, iou_threshold=(0.4, 0.5), num_classes=20, *args, **kwargs):
        super(ProposalGenerator, self).__init__(*args, **kwargs)
        # accumulated per-image foreground / background Proposal objects
        self.fg_proposal_list = []
        self.bg_proposal_list = []
        # (low, high) IoU thresholds: <= low => background, (low, high] => ignore (-1)
        self.iou_threshold = iou_threshold
        self.num_classes = num_classes

    def process_type(self, inputs, outputs, type='instances'):
        """Build a Proposal from the first image's predictions of the given output
        key (`'instances'` for foreground, `'background'` for sampled background),
        and label each predicted box with ground truth when it is available."""
        cpu_device = torch.device('cpu')
        input_instance = inputs[0]['instances'].to(cpu_device)
        output_instance = outputs[0][type].to(cpu_device)
        filename = inputs[0]['file_name']
        pred_boxes = output_instance.pred_boxes
        pred_scores = output_instance.scores
        pred_classes = output_instance.pred_classes
        proposal = Proposal(
            image_id=inputs[0]['image_id'],
            filename=filename,
            pred_boxes=pred_boxes.tensor.numpy(),
            pred_classes=pred_classes.numpy(),
            pred_scores=pred_scores.numpy(),
        )
        if hasattr(input_instance, 'gt_boxes'):
            gt_boxes = input_instance.gt_boxes
            # assign a gt label for each pred_box
            if pred_boxes.tensor.shape[0] == 0:
                # no predictions: all gt_* fields become empty arrays
                proposal.gt_fg_classes = proposal.gt_classes = proposal.gt_ious = proposal.gt_boxes = np.array([])
            elif gt_boxes.tensor.shape[0] == 0:
                # image has no ground truth: every prediction is background (IoU 0)
                proposal.gt_fg_classes = proposal.gt_classes = np.array(
                    [self.num_classes for _ in range(pred_boxes.tensor.shape[0])])
                proposal.gt_ious = np.array([0. for _ in range(pred_boxes.tensor.shape[0])])
                proposal.gt_boxes = np.array([[0, 0, 0, 0] for _ in range(pred_boxes.tensor.shape[0])])
            else:
                # best-matching gt box for each prediction
                gt_ious, gt_classes_idx = pairwise_iou(pred_boxes, gt_boxes).max(dim=1)
                gt_classes = input_instance.gt_classes[gt_classes_idx]
                # deepcopy BEFORE the in-place threshold edits below so gt_fg_classes
                # keeps the raw foreground labels
                proposal.gt_fg_classes = copy.deepcopy(gt_classes.numpy())
                gt_classes[gt_ious <= self.iou_threshold[0]] = self.num_classes  # background classes
                gt_classes[(self.iou_threshold[0] < gt_ious) & (gt_ious <= self.iou_threshold[1])] = -1  # ignore
                proposal.gt_classes = gt_classes.numpy()
                proposal.gt_ious = gt_ious.numpy()
                proposal.gt_boxes = input_instance.gt_boxes[gt_classes_idx].tensor.numpy()
        return proposal

    def process(self, inputs, outputs):
        # accumulate one foreground and one background Proposal per processed image
        self.fg_proposal_list.append(self.process_type(inputs, outputs, "instances"))
        self.bg_proposal_list.append(self.process_type(inputs, outputs, "background"))

    def evaluate(self):
        return self.fg_proposal_list, self.bg_proposal_list


class Proposal:
    """
    A data structure that stores the proposals for a single image.

    Args:
        image_id (str): unique image identifier
        filename (str): image filename
        pred_boxes (numpy.ndarray): predicted boxes
        pred_classes (numpy.ndarray): predicted classes
        pred_scores (numpy.ndarray): class confidence score
        gt_classes (numpy.ndarray, optional): ground-truth classes, including background classes
        gt_boxes (numpy.ndarray, optional): ground-truth boxes
        gt_ious (numpy.ndarray, optional): IoU between predicted boxes and ground-truth boxes
        gt_fg_classes (numpy.ndarray, optional): ground-truth foreground classes,
            not including background classes
    """
    def __init__(self, image_id, filename, pred_boxes, pred_classes, pred_scores,
                 gt_classes=None, gt_boxes=None, gt_ious=None, gt_fg_classes=None):
        self.image_id = image_id
        self.filename = filename
        self.pred_boxes = pred_boxes
        self.pred_classes = pred_classes
        self.pred_scores = pred_scores
        self.gt_classes = gt_classes
        self.gt_boxes = gt_boxes
        self.gt_ious = gt_ious
        self.gt_fg_classes = gt_fg_classes

    def to_dict(self):
        # JSON-serializable form; "__proposal__" marks the dict for asProposal().
        # NOTE(review): assumes all gt_* fields are populated (raises AttributeError
        # on None) — confirm callers only serialize proposals with ground truth.
        return {
            "__proposal__": True,
            "image_id": self.image_id,
            "filename": self.filename,
            "pred_boxes": self.pred_boxes.tolist(),
            "pred_classes": self.pred_classes.tolist(),
            "pred_scores": self.pred_scores.tolist(),
            "gt_classes": self.gt_classes.tolist(),
            "gt_boxes": self.gt_boxes.tolist(),
            "gt_ious": self.gt_ious.tolist(),
            "gt_fg_classes": self.gt_fg_classes.tolist()
        }

    def __str__(self):
        pp = pprint.PrettyPrinter(indent=2)
        return pp.pformat(self.to_dict())

    def __len__(self):
        # number of predicted boxes in this image
        return len(self.pred_boxes)

    def __getitem__(self, item):
        # Index/slice every per-box array in lockstep; image-level fields are shared.
        return Proposal(
            image_id=self.image_id,
            filename=self.filename,
            pred_boxes=self.pred_boxes[item],
            pred_classes=self.pred_classes[item],
            pred_scores=self.pred_scores[item],
            gt_classes=self.gt_classes[item],
            gt_boxes=self.gt_boxes[item],
            gt_ious=self.gt_ious[item],
            gt_fg_classes=self.gt_fg_classes[item]
        )


class ProposalEncoder(json.JSONEncoder):
    """JSON encoder that serializes Proposal objects via :meth:`Proposal.to_dict`."""
    def default(self, obj):
        if isinstance(obj, Proposal):
            return obj.to_dict()
        return json.JSONEncoder.default(self, obj)


def
asProposal(dict): if '__proposal__' in dict: return Proposal( dict["image_id"], dict["filename"], np.array(dict["pred_boxes"]), np.array(dict["pred_classes"]), np.array(dict["pred_scores"]), np.array(dict["gt_classes"]), np.array(dict["gt_boxes"]), np.array(dict["gt_ious"]), np.array(dict["gt_fg_classes"]) ) return dict class PersistentProposalList(list): """ A data structure that stores the proposals for a dataset. Args: filename (str, optional): filename indicating where to cache """ def __init__(self, filename=None): super(PersistentProposalList, self).__init__() self.filename = filename def load(self): """ Load from cache. Return: whether succeed """ if os.path.exists(self.filename): print("Reading from cache: {}".format(self.filename)) with open(self.filename, "r") as f: self.extend(json.load(f, object_hook=asProposal)) return True else: return False def flush(self): """ Flush to cache. """ os.makedirs(os.path.dirname(self.filename), exist_ok=True) with open(self.filename, "w") as f: json.dump(self, f, cls=ProposalEncoder) print("Write to cache: {}".format(self.filename)) def flatten(proposal_list, max_number=10000): """ Flatten a list of proposals Args: proposal_list (list): a list of proposals grouped by images max_number (int): maximum number of kept proposals for each image """ flattened_list = [] for proposals in proposal_list: for i in range(min(len(proposals), max_number)): flattened_list.append(proposals[i:i+1]) return flattened_list class ProposalDataset(datasets.VisionDataset): """ A dataset for proposals. Args: proposal_list (list): list of Proposal transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. 
E.g, ``transforms.RandomCrop`` crop_func: (ExpandCrop, optional): """ def __init__(self, proposal_list: List[Proposal], transform: Optional[Callable] = None, crop_func=None): super(ProposalDataset, self).__init__("", transform=transform) self.proposal_list = list(filter(lambda p: len(p) > 0, proposal_list)) # remove images without proposals self.loader = default_loader self.crop_func = crop_func def __getitem__(self, index: int): # get proposals for the index-th image proposals = self.proposal_list[index] img = self.loader(proposals.filename) # random sample a proposal proposal = proposals[random.randint(0, len(proposals)-1)] image_width, image_height = img.width, img.height # proposal_dict = proposal.to_dict() # proposal_dict.update(width=img.width, height=img.height) # crop the proposal from the whole image x1, y1, x2, y2 = proposal.pred_boxes top, left, height, width = int(y1), int(x1), int(y2 - y1), int(x2 - x1) if self.crop_func is not None: top, left, height, width = self.crop_func(img, top, left, height, width) img = crop(img, top, left, height, width) if self.transform is not None: img = self.transform(img) return img, { "image_id": proposal.image_id, "filename": proposal.filename, "pred_boxes": proposal.pred_boxes.astype(np.float), "pred_classes": proposal.pred_classes.astype(np.long), "pred_scores": proposal.pred_scores.astype(np.float), "gt_classes": proposal.gt_classes.astype(np.long), "gt_boxes": proposal.gt_boxes.astype(np.float), "gt_ious": proposal.gt_ious.astype(np.float), "gt_fg_classes": proposal.gt_fg_classes.astype(np.long), "width": image_width, "height": image_height } def __len__(self): return len(self.proposal_list) class ExpandCrop: """ The input of the bounding box adaptor (the crops of objects) will be larger than the original predicted box, so that the bounding box adapter could access more location information. 
""" def __init__(self, expand=1.): self.expand = expand def __call__(self, img, top, left, height, width): cx = left + width / 2. cy = top + height / 2. height = round(height * self.expand) width = round(width * self.expand) new_top = round(cy - height / 2.) new_left = round(cx - width / 2.) return new_top, new_left, height, width ================================================ FILE: tllib/alignment/dan.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional, Sequence import torch import torch.nn as nn from tllib.modules.classifier import Classifier as ClassifierBase __all__ = ['MultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier'] class MultipleKernelMaximumMeanDiscrepancy(nn.Module): r"""The Multiple Kernel Maximum Mean Discrepancy (MK-MMD) used in `Learning Transferable Features with Deep Adaptation Networks (ICML 2015) `_ Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target domain :math:`\mathcal{D}_t` of :math:`n_t` unlabeled points drawn i.i.d. from P and Q respectively, the deep networks will generate activations as :math:`\{z_i^s\}_{i=1}^{n_s}` and :math:`\{z_i^t\}_{i=1}^{n_t}`. The MK-MMD :math:`D_k (P, Q)` between probability distributions P and Q is defined as .. math:: D_k(P, Q) \triangleq \| E_p [\phi(z^s)] - E_q [\phi(z^t)] \|^2_{\mathcal{H}_k}, :math:`k` is a kernel function in the function space .. math:: \mathcal{K} \triangleq \{ k=\sum_{u=1}^{m}\beta_{u} k_{u} \} where :math:`k_{u}` is a single kernel. Using kernel trick, MK-MMD can be computed as .. math:: \hat{D}_k(P, Q) &= \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} k(z_i^{s}, z_j^{s})\\ &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} k(z_i^{t}, z_j^{t})\\ &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} k(z_i^{s}, z_j^{t}).\\ Args: kernels (tuple(torch.nn.Module)): kernel functions. linear (bool): whether use the linear version of DAN. 
Default: False Inputs: - z_s (tensor): activations from the source domain, :math:`z^s` - z_t (tensor): activations from the target domain, :math:`z^t` Shape: - Inputs: :math:`(minibatch, *)` where * means any dimension - Outputs: scalar .. note:: Activations :math:`z^{s}` and :math:`z^{t}` must have the same shape. .. note:: The kernel values will add up when there are multiple kernels. Examples:: >>> from tllib.modules.kernels import GaussianKernel >>> feature_dim = 1024 >>> batch_size = 10 >>> kernels = (GaussianKernel(alpha=0.5), GaussianKernel(alpha=1.), GaussianKernel(alpha=2.)) >>> loss = MultipleKernelMaximumMeanDiscrepancy(kernels) >>> # features from source domain and target domain >>> z_s, z_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim) >>> output = loss(z_s, z_t) """ def __init__(self, kernels: Sequence[nn.Module], linear: Optional[bool] = False): super(MultipleKernelMaximumMeanDiscrepancy, self).__init__() self.kernels = kernels self.index_matrix = None self.linear = linear def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor: features = torch.cat([z_s, z_t], dim=0) batch_size = int(z_s.size(0)) self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s.device) kernel_matrix = sum([kernel(features) for kernel in self.kernels]) # Add up the matrix of each kernel # Add 2 / (n-1) to make up for the value on the diagonal # to ensure loss is positive in the non-linear version loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1) return loss def _update_index_matrix(batch_size: int, index_matrix: Optional[torch.Tensor] = None, linear: Optional[bool] = True) -> torch.Tensor: r""" Update the `index_matrix` which convert `kernel_matrix` to loss. If `index_matrix` is a tensor with shape (2 x batch_size, 2 x batch_size), then return `index_matrix`. Else return a new tensor with shape (2 x batch_size, 2 x batch_size). 
""" if index_matrix is None or index_matrix.size(0) != batch_size * 2: index_matrix = torch.zeros(2 * batch_size, 2 * batch_size) if linear: for i in range(batch_size): s1, s2 = i, (i + 1) % batch_size t1, t2 = s1 + batch_size, s2 + batch_size index_matrix[s1, s2] = 1. / float(batch_size) index_matrix[t1, t2] = 1. / float(batch_size) index_matrix[s1, t2] = -1. / float(batch_size) index_matrix[s2, t1] = -1. / float(batch_size) else: for i in range(batch_size): for j in range(batch_size): if i != j: index_matrix[i][j] = 1. / float(batch_size * (batch_size - 1)) index_matrix[i + batch_size][j + batch_size] = 1. / float(batch_size * (batch_size - 1)) for i in range(batch_size): for j in range(batch_size): index_matrix[i][j + batch_size] = -1. / float(batch_size * batch_size) index_matrix[i + batch_size][j] = -1. / float(batch_size * batch_size) return index_matrix class ImageClassifier(ClassifierBase): def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): bottleneck = nn.Sequential( # nn.AdaptiveAvgPool2d(output_size=(1, 1)), # nn.Flatten(), nn.Linear(backbone.out_features, bottleneck_dim), nn.ReLU(), nn.Dropout(0.5) ) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) ================================================ FILE: tllib/alignment/dann.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from tllib.modules.grl import WarmStartGradientReverseLayer from tllib.modules.classifier import Classifier as ClassifierBase from tllib.utils.metric import binary_accuracy, accuracy __all__ = ['DomainAdversarialLoss'] class DomainAdversarialLoss(nn.Module): r""" The Domain Adversarial Loss proposed in `Domain-Adversarial Training of Neural Networks (ICML 2015) `_ Domain adversarial loss measures the domain 
    discrepancy through training a domain discriminator.
    Given domain discriminator :math:`D`, feature representation :math:`f`,
    the definition of DANN loss is

    .. math::
        loss(\mathcal{D}_s, \mathcal{D}_t) = \mathbb{E}_{x_i^s \sim \mathcal{D}_s} \text{log}[D(f_i^s)]
        + \mathbb{E}_{x_j^t \sim \mathcal{D}_t} \text{log}[1-D(f_j^t)].

    Args:
        domain_discriminator (torch.nn.Module): A domain discriminator object, which predicts
            the domains of features. Its input shape is (N, F) and output shape is (N, 1)
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in
            the output, ``'sum'``: the output will be summed. Default: ``'mean'``
        grl (WarmStartGradientReverseLayer, optional): Default: None.

    Inputs:
        - f_s (tensor): feature representations on source domain, :math:`f^s`
        - f_t (tensor): feature representations on target domain, :math:`f^t`
        - w_s (tensor, optional): a rescaling weight given to each instance from source domain.
        - w_t (tensor, optional): a rescaling weight given to each instance from target domain.

    Shape:
        - f_s, f_t: :math:`(N, F)` where F means the dimension of input features.
        - Outputs: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, )`.

    Examples::

        >>> from tllib.modules.domain_discriminator import DomainDiscriminator
        >>> discriminator = DomainDiscriminator(in_feature=1024, hidden_size=1024)
        >>> loss = DomainAdversarialLoss(discriminator, reduction='mean')
        >>> # features from source domain and target domain
        >>> f_s, f_t = torch.randn(20, 1024), torch.randn(20, 1024)
        >>> # If you want to assign different weights to each instance, you should pass in w_s and w_t
        >>> w_s, w_t = torch.randn(20), torch.randn(20)
        >>> output = loss(f_s, f_t, w_s, w_t)
    """

    def __init__(self, domain_discriminator: nn.Module, reduction: Optional[str] = 'mean',
                 grl: Optional = None, sigmoid=True):
        super(DomainAdversarialLoss, self).__init__()
        # default GRL warm-starts its reversal coefficient from 0 to 1 over 1000 steps
        self.grl = WarmStartGradientReverseLayer(alpha=1., lo=0., hi=1., max_iters=1000,
                                                 auto_step=True) if grl is None else grl
        self.domain_discriminator = domain_discriminator
        # sigmoid=True: discriminator outputs (N, 1) probabilities (BCE);
        # sigmoid=False: discriminator outputs (N, 2) logits (cross entropy)
        self.sigmoid = sigmoid
        self.reduction = reduction
        self.bce = lambda input, target, weight: \
            F.binary_cross_entropy(input, target, weight=weight, reduction=reduction)
        # updated as a side effect of each forward() call, for logging
        self.domain_discriminator_accuracy = None

    def forward(self, f_s: torch.Tensor, f_t: torch.Tensor,
                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        # gradients flowing back through the discriminator are reversed here,
        # so minimizing this loss trains features adversarially
        f = self.grl(torch.cat((f_s, f_t), dim=0))
        d = self.domain_discriminator(f)
        if self.sigmoid:
            d_s, d_t = d.chunk(2, dim=0)
            # source label = 1, target label = 0
            d_label_s = torch.ones((f_s.size(0), 1)).to(f_s.device)
            d_label_t = torch.zeros((f_t.size(0), 1)).to(f_t.device)
            self.domain_discriminator_accuracy = 0.5 * (
                binary_accuracy(d_s, d_label_s) + binary_accuracy(d_t, d_label_t))

            if w_s is None:
                w_s = torch.ones_like(d_label_s)
            if w_t is None:
                w_t = torch.ones_like(d_label_t)
            return 0.5 * (
                F.binary_cross_entropy(d_s, d_label_s, weight=w_s.view_as(d_s), reduction=self.reduction) +
                F.binary_cross_entropy(d_t, d_label_t, weight=w_t.view_as(d_t), reduction=self.reduction)
            )
        else:
            d_label = torch.cat((
                torch.ones((f_s.size(0),)).to(f_s.device),
                torch.zeros((f_t.size(0),)).to(f_t.device),
            )).long()
            if w_s is None:
                w_s = torch.ones((f_s.size(0),)).to(f_s.device)
            if w_t is None:
                w_t = torch.ones((f_t.size(0),)).to(f_t.device)
            self.domain_discriminator_accuracy = accuracy(d, d_label)
            # per-instance weighting requires reduction='none' here, then reduce manually
            loss = F.cross_entropy(d, d_label, reduction='none') * torch.cat([w_s, w_t], dim=0)
            if self.reduction == "mean":
                return loss.mean()
            elif self.reduction == "sum":
                return loss.sum()
            elif self.reduction == "none":
                return loss
            else:
                raise NotImplementedError(self.reduction)


class ImageClassifier(ClassifierBase):
    """Classifier for DANN: bottleneck = Linear -> BatchNorm1d -> ReLU."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten(),
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)


================================================
FILE: tllib/alignment/jan.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, Sequence
import torch
import torch.nn as nn

from tllib.modules.classifier import Classifier as ClassifierBase
from tllib.modules.grl import GradientReverseLayer
from tllib.modules.kernels import GaussianKernel
from tllib.alignment.dan import _update_index_matrix

__all__ = ['JointMultipleKernelMaximumMeanDiscrepancy', 'ImageClassifier']


class JointMultipleKernelMaximumMeanDiscrepancy(nn.Module):
    r"""The Joint Multiple Kernel Maximum Mean Discrepancy (JMMD) used in
    `Deep Transfer Learning with Joint Adaptation Networks (ICML 2017) `_

    Given source domain :math:`\mathcal{D}_s` of :math:`n_s` labeled points and target
    domain :math:`\mathcal{D}_t` of :math:`n_t` unlabeled points drawn i.i.d.
    from P and Q respectively, the deep networks will generate activations in
    layers :math:`\mathcal{L}` as :math:`\{(z_i^{s1}, ..., z_i^{s|\mathcal{L}|})\}_{i=1}^{n_s}` and
    :math:`\{(z_i^{t1}, ..., z_i^{t|\mathcal{L}|})\}_{i=1}^{n_t}`. The empirical estimate of
    :math:`\hat{D}_{\mathcal{L}}(P, Q)` is computed as the squared distance between the empirical
    kernel mean embeddings as

    .. math::
        \hat{D}_{\mathcal{L}}(P, Q) &=
        \dfrac{1}{n_s^2} \sum_{i=1}^{n_s}\sum_{j=1}^{n_s} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{sl}) \\
        &+ \dfrac{1}{n_t^2} \sum_{i=1}^{n_t}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{tl}, z_j^{tl}) \\
        &- \dfrac{2}{n_s n_t} \sum_{i=1}^{n_s}\sum_{j=1}^{n_t} \prod_{l\in\mathcal{L}} k^l(z_i^{sl}, z_j^{tl}). \\

    Args:
        kernels (tuple(tuple(torch.nn.Module))): kernel functions, where `kernels[r]` corresponds
            to kernel :math:`k^{\mathcal{L}[r]}`.
        linear (bool): whether use the linear version of JAN. Default: False
        thetas (list(Theta): use adversarial version JAN if not None. Default: None

    Inputs:
        - z_s (tuple(tensor)): multiple layers' activations from the source domain, :math:`z^s`
        - z_t (tuple(tensor)): multiple layers' activations from the target domain, :math:`z^t`

    Shape:
        - :math:`z^{sl}` and :math:`z^{tl}`: :math:`(minibatch, *)`  where * means any dimension
        - Outputs: scalar

    .. note::
        Activations :math:`z^{sl}` and :math:`z^{tl}` must have the same shape.

    .. note::
        The kernel values will add up when there are multiple kernels for a certain layer.

    Examples::

        >>> feature_dim = 1024
        >>> batch_size = 10
        >>> layer1_kernels = (GaussianKernel(alpha=0.5), GaussianKernel(1.), GaussianKernel(2.))
        >>> layer2_kernels = (GaussianKernel(1.), )
        >>> loss = JointMultipleKernelMaximumMeanDiscrepancy((layer1_kernels, layer2_kernels))
        >>> # layer1 features from source domain and target domain
        >>> z1_s, z1_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> # layer2 features from source domain and target domain
        >>> z2_s, z2_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim)
        >>> output = loss((z1_s, z2_s), (z1_t, z2_t))
    """

    def __init__(self, kernels: Sequence[Sequence[nn.Module]], linear: Optional[bool] = True,
                 thetas: Sequence[nn.Module] = None):
        super(JointMultipleKernelMaximumMeanDiscrepancy, self).__init__()
        self.kernels = kernels
        # cached coefficient matrix, rebuilt lazily when the batch size changes
        self.index_matrix = None
        self.linear = linear
        if thetas:
            self.thetas = thetas
        else:
            # non-adversarial JAN: identity transform per layer
            self.thetas = [nn.Identity() for _ in kernels]

    def forward(self, z_s: torch.Tensor, z_t: torch.Tensor) -> torch.Tensor:
        batch_size = int(z_s[0].size(0))
        self.index_matrix = _update_index_matrix(batch_size, self.index_matrix, self.linear).to(z_s[0].device)

        # joint kernel = element-wise PRODUCT over layers of each layer's summed kernels
        kernel_matrix = torch.ones_like(self.index_matrix)
        for layer_z_s, layer_z_t, layer_kernels, theta in zip(z_s, z_t, self.kernels, self.thetas):
            layer_features = torch.cat([layer_z_s, layer_z_t], dim=0)
            layer_features = theta(layer_features)
            kernel_matrix *= sum(
                [kernel(layer_features) for kernel in layer_kernels])  # Add up the matrix of each kernel

        # Add 2 / (n-1) to make up for the value on the diagonal
        # to ensure loss is positive in the non-linear version
        loss = (kernel_matrix * self.index_matrix).sum() + 2. / float(batch_size - 1)
        return loss


class Theta(nn.Module):
    r"""
    maximize loss respect to :math:`\theta`
    minimize loss respect to features
    """
    def __init__(self, dim: int):
        super(Theta, self).__init__()
        # two stacked gradient reversals cancel for the features but give theta
        # an ascent direction, making the layer adversarial w.r.t. its own weights
        self.grl1 = GradientReverseLayer()
        self.grl2 = GradientReverseLayer()
        self.layer1 = nn.Linear(dim, dim)
        # start as the identity map so training begins from plain JMMD
        nn.init.eye_(self.layer1.weight)
        nn.init.zeros_(self.layer1.bias)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        features = self.grl1(features)
        return self.grl2(self.layer1(features))


class ImageClassifier(ClassifierBase):
    """Classifier for JAN: bottleneck = Linear -> BatchNorm1d -> ReLU -> Dropout(0.5)."""
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            # nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            # nn.Flatten(),
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)


================================================
FILE: tllib/alignment/mcd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import torch.nn as nn
import torch


def classifier_discrepancy(predictions1: torch.Tensor, predictions2: torch.Tensor) -> torch.Tensor:
    r"""The `Classifier Discrepancy` in
    `Maximum Classifier Discrepancy for Unsupervised Domain Adaptation (CVPR 2018) `_.

    The classfier discrepancy between predictions :math:`p_1` and :math:`p_2` can be described as:

    .. math::
        d(p_1, p_2) = \dfrac{1}{K} \sum_{k=1}^K | p_{1k} - p_{2k} |,

    where K is number of classes.

    Args:
        predictions1 (torch.Tensor): Classifier predictions :math:`p_1`. Expected to contain raw, normalized scores for each class
        predictions2 (torch.Tensor): Classifier predictions :math:`p_2`
    """
    # mean absolute difference over both batch and class dimensions (L1 discrepancy)
    return torch.mean(torch.abs(predictions1 - predictions2))


def entropy(predictions: torch.Tensor) -> torch.Tensor:
    r"""Entropy of N predictions :math:`(p_1, p_2, ..., p_N)`.
    The definition is:

    .. math::
        d(p_1, p_2, ..., p_N) = -\dfrac{1}{K} \sum_{k=1}^K \log \left( \dfrac{1}{N} \sum_{i=1}^N p_{ik} \right)

    where K is number of classes.

    .. note::
        This entropy function is specifically used in MCD and different from the usual
        :meth:`~tllib.modules.entropy.entropy` function.

    Args:
        predictions (torch.Tensor): Classifier predictions. Expected to contain raw, normalized scores for each class
    """
    # average probabilities over the batch (dim 0) first; 1e-6 avoids log(0)
    return -torch.mean(torch.log(torch.mean(predictions, 0) + 1e-6))


class ImageClassifierHead(nn.Module):
    r"""Classifier Head for MCD.

    Args:
        in_features (int): Dimension of input features
        num_classes (int): Number of classes
        bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1024

    Shape:
        - Inputs: :math:`(minibatch, F)` where F = `in_features`.
        - Output: :math:`(minibatch, C)` where C = `num_classes`.
    """

    def __init__(self, in_features: int, num_classes: int, bottleneck_dim: Optional[int] = 1024,
                 pool_layer=None):
        super(ImageClassifierHead, self).__init__()
        self.num_classes = num_classes
        if pool_layer is None:
            # default pooling: global average pool + flatten for conv feature maps
            self.pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        else:
            self.pool_layer = pool_layer
        # two hidden layers (Dropout -> Linear -> BN -> ReLU) before the final logits
        self.head = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(bottleneck_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Linear(bottleneck_dim, num_classes)
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.head(self.pool_layer(inputs))


================================================
FILE: tllib/alignment/mdd.py
================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Dict, Tuple, Callable
import torch.nn as nn
import torch.nn.functional as F
import torch

from tllib.modules.grl import WarmStartGradientReverseLayer


class MarginDisparityDiscrepancy(nn.Module):
    r"""The margin disparity discrepancy (MDD) proposed in
    `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019) `_.

    MDD can measure the distribution discrepancy in domain adaptation.

    The :math:`y^s` and :math:`y^t` are logits output by the main head on the source and target
    domain respectively. The :math:`y_{adv}^s` and :math:`y_{adv}^t` are logits output by the
    adversarial head.

    The definition can be described as:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        -\gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} L_s (y^s, y_{adv}^s) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} L_t (y^t, y_{adv}^t),

    where :math:`\gamma` is a margin hyper-parameter, :math:`L_s` refers to the disparity function
    defined on the source domain and :math:`L_t` refers to the disparity function defined on the
    target domain.

    Args:
        source_disparity (callable): The disparity function defined on the source domain, :math:`L_s`.
        target_disparity (callable): The disparity function defined on the target domain, :math:`L_t`.
        margin (float): margin :math:`\gamma`. Default: 4
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the sum of the output will be divided by the number of elements in
            the output, ``'sum'``: the output will be summed. Default: ``'mean'``

    Inputs:
        - y_s: output :math:`y^s` by the main head on the source domain
        - y_s_adv: output :math:`y^s` by the adversarial head on the source domain
        - y_t: output :math:`y^t` by the main head on the target domain
        - y_t_adv: output :math:`y_{adv}^t` by the adversarial head on the target domain
        - w_s (optional): instance weights for source domain
        - w_t (optional): instance weights for target domain

    Examples::

        >>> num_outputs = 2
        >>> batch_size = 10
        >>> loss = MarginDisparityDiscrepancy(margin=4., source_disparity=F.l1_loss, target_disparity=F.l1_loss)
        >>> # output from source domain and target domain
        >>> y_s, y_t = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)
        >>> # adversarial output from source domain and target domain
        >>> y_s_adv, y_t_adv = torch.randn(batch_size, num_outputs), torch.randn(batch_size, num_outputs)
        >>> output = loss(y_s, y_s_adv, y_t, y_t_adv)
    """

    def __init__(self, source_disparity: Callable, target_disparity: Callable,
                 margin: Optional[float] = 4, reduction: Optional[str] = 'mean'):
        super(MarginDisparityDiscrepancy, self).__init__()
        self.margin = margin
        self.reduction = reduction
        self.source_disparity = source_disparity
        self.target_disparity = target_disparity

    def forward(self, y_s: torch.Tensor, y_s_adv: torch.Tensor, y_t: torch.Tensor, y_t_adv: torch.Tensor,
                w_s: Optional[torch.Tensor] = None, w_t: Optional[torch.Tensor] = None) -> torch.Tensor:
        # source term is scaled by -margin per the MDD objective
        source_loss = -self.margin * self.source_disparity(y_s, y_s_adv)
        target_loss = self.target_disparity(y_t, y_t_adv)
        if w_s is None:
            w_s = torch.ones_like(source_loss)
        source_loss = source_loss * w_s
        if w_t is None:
            w_t = torch.ones_like(target_loss)
        target_loss = target_loss * w_t

        loss = source_loss + target_loss
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss


class ClassificationMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):
    r"""
    The margin disparity discrepancy (MDD) proposed in
    `Bridging
class ClassificationMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):
    r"""The margin disparity discrepancy (MDD) for classification, proposed in
    `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019)
    <https://arxiv.org/abs/1904.05801>`_. With margin 1 it is also called
    disparity discrepancy (DD).

    :math:`y^s, y^t` are raw logits of the main classifier and
    :math:`y_{adv}^s, y_{adv}^t` are raw logits of the adversarial classifier:

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        \gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}}
        \log\left(\frac{\exp(y_{adv}^s[h_{y^s}])}{\sum_j \exp(y_{adv}^s[j])}\right) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}}
        \log\left(1-\frac{\exp(y_{adv}^t[h_{y^t}])}{\sum_j \exp(y_{adv}^t[j])}\right),

    where :math:`\gamma` is the margin and :math:`h_y` is the hard label predicted
    from logits :math:`y`.

    Args:
        margin (float): margin :math:`\gamma`. Default: 4
        reduction (str, optional): ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - y_s / y_s_adv: logits by the main / adversarial classifier on the source domain
        - y_t / y_t_adv: logits by the main / adversarial classifier on the target domain

    Shape:
        - Inputs: :math:`(minibatch, C)` where C = number of classes.
        - Output: scalar unless :attr:`reduction` is ``'none'``, then :math:`(minibatch)`.
    """

    def __init__(self, margin: Optional[float] = 4, **kwargs):
        def source_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            # Cross entropy of the adversarial logits against the main classifier's hard label.
            hard_label = y.argmax(dim=1)
            return F.cross_entropy(y_adv, hard_label, reduction='none')

        def target_discrepancy(y: torch.Tensor, y_adv: torch.Tensor):
            # log(1 - p_adv[hard_label]); shift_log guards against log(0).
            hard_label = y.argmax(dim=1)
            return -F.nll_loss(shift_log(1. - F.softmax(y_adv, dim=1)), hard_label, reduction='none')

        super(ClassificationMarginDisparityDiscrepancy, self).__init__(
            source_discrepancy, target_discrepancy, margin, **kwargs)
class RegressionMarginDisparityDiscrepancy(MarginDisparityDiscrepancy):
    r"""The margin disparity discrepancy (MDD) for regression, proposed in
    `Bridging Theory and Algorithm for Domain Adaptation (ICML 2019)
    <https://arxiv.org/abs/1904.05801>`_.

    :math:`y^s, y^t` are outputs of the main regressor and
    :math:`y_{adv}^s, y_{adv}^t` outputs of the adversarial regressor
    (expected to contain normalized values for each factor):

    .. math::
        \mathcal{D}_{\gamma}(\hat{\mathcal{S}}, \hat{\mathcal{T}}) =
        -\gamma \mathbb{E}_{y^s, y_{adv}^s \sim\hat{\mathcal{S}}} L(y^s, y_{adv}^s) +
        \mathbb{E}_{y^t, y_{adv}^t \sim\hat{\mathcal{T}}} L(y^t, y_{adv}^t),

    where :math:`\gamma` is the margin and :math:`L` the disparity function used
    on both domains.

    Args:
        margin (float): margin :math:`\gamma`. Default: 1
        loss_function (callable): disparity function :math:`L`. Default: ``F.l1_loss``
        reduction (str, optional): ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - y_s / y_s_adv: outputs by the main / adversarial regressor on the source domain
        - y_t / y_t_adv: outputs by the main / adversarial regressor on the target domain

    Shape:
        - Inputs: :math:`(minibatch, F)` where F = number of factors.
        - Output: scalar unless :attr:`reduction` is ``'none'``, then :math:`(minibatch)`.
    """

    def __init__(self, margin: Optional[float] = 1, loss_function=F.l1_loss, **kwargs):
        # The same disparity is applied on both domains; the regression target
        # (the main head's output) is detached so gradients only flow through
        # the adversarial head.
        def disparity(y: torch.Tensor, y_adv: torch.Tensor):
            return loss_function(y_adv, y.detach(), reduction='none')

        super(RegressionMarginDisparityDiscrepancy, self).__init__(
            disparity, disparity, margin, **kwargs)
def shift_log(x: torch.Tensor, offset: Optional[float] = 1e-6) -> torch.Tensor:
    r"""
    First shift, then clamp, then take the log:

    .. math::
        y = \log(\min(x + \text{offset}, 1))

    The offset avoids the infinite gradient of :math:`\log(x)` at :math:`x=0`,
    and the clamp keeps the argument at most 1 so the output stays non-positive.
    (The previous docstring stated :math:`y = \max(\log(x+\text{offset}), 0)`,
    which is the opposite of what the code computes.)

    Args:
        x (torch.Tensor): input tensor, expected to fall in [0., 1.]
        offset (float, optional): offset size. Default: 1e-6

    .. note::
        For input in [0., 1.] the output falls in [log(offset), 0].
    """
    return torch.log(torch.clamp(x + offset, max=1.))
class GeneralModule(nn.Module):
    """Generic MDD network: a shared backbone and bottleneck feed a main head and
    an adversarial head, the latter behind a warm-start gradient-reverse layer.

    Args:
        backbone (nn.Module): feature extractor
        num_classes (int): number of classes
        bottleneck (nn.Module): projection applied to backbone features
        head (nn.Module): main prediction head
        adv_head (nn.Module): adversarial head used only by the MDD loss
        grl (WarmStartGradientReverseLayer, optional): gradient reverse layer;
            a default warm-start GRL is created when None. Default: None
        finetune (bool, optional): use a 10x smaller learning rate for the backbone. Default: True
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: nn.Module,
                 head: nn.Module, adv_head: nn.Module, grl: Optional[WarmStartGradientReverseLayer] = None,
                 finetune: Optional[bool] = True):
        super(GeneralModule, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.bottleneck = bottleneck
        self.head = head
        self.adv_head = adv_head
        self.finetune = finetune
        # Default GRL warms lambda from 0 to 0.1 over 1000 manual steps (auto_step=False),
        # so callers must invoke step() each training iteration.
        if grl is None:
            grl = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,
                                                auto_step=False)
        self.grl_layer = grl

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (outputs, outputs_adv) in training mode, outputs only in eval mode."""
        f = self.bottleneck(self.backbone(x))
        outputs = self.head(f)
        outputs_adv = self.adv_head(self.grl_layer(f))
        if self.training:
            return outputs, outputs_adv
        return outputs

    def step(self):
        r"""Gradually increase :math:`\lambda` in the GRL layer."""
        self.grl_layer.step()

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """Parameter groups encoding the relative learning rate of each sub-module."""
        backbone_lr = 0.1 * base_lr if self.finetune else base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": base_lr},
            {"params": self.head.parameters(), "lr": base_lr},
            {"params": self.adv_head.parameters(), "lr": base_lr},
        ]
class ImageClassifier(GeneralModule):
    r"""Classifier for MDD: one backbone, one bottleneck, and two classifier heads.
    The first head produces the final predictions; the adversarial head is only
    used when calculating :class:`MarginDisparityDiscrepancy`.

    Args:
        backbone (torch.nn.Module): any backbone that extracts 1-d features from data
        num_classes (int): number of classes
        bottleneck_dim (int, optional): feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): feature dimension of the classifier head. Default: 1024
        grl (nn.Module): gradient reverse layer; default parameters are used when None. Default: None
        finetune (bool, optional): use a 10x smaller learning rate for the backbone. Default: True
        pool_layer (nn.Module, optional): pooling applied before the bottleneck;
            defaults to global average pooling + flatten.

    Inputs:
        - x (tensor): input data

    Outputs:
        - outputs: logits by the main classifier
        - outputs_adv: logits by the adversarial classifier

    Shape:
        - x: :math:`(minibatch, *)`, same as the backbone's input.
        - outputs, outputs_adv: :math:`(minibatch, C)` where C is the number of classes.

    .. note::
        Remember to call :meth:`step` after :meth:`forward` **during training**:

        >>> # x is inputs, classifier is an ImageClassifier
        >>> outputs, outputs_adv = classifier(x)
        >>> classifier.step()
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 1024,
                 width: Optional[int] = 1024, grl: Optional[WarmStartGradientReverseLayer] = None,
                 finetune=True, pool_layer=None):
        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,
                                                  auto_step=False) if grl is None else grl
        if pool_layer is None:
            pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        bottleneck = nn.Sequential(
            pool_layer,
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        # Bottleneck projection: N(0, 0.005) weights, constant positive bias.
        bottleneck[1].weight.data.normal_(0, 0.005)
        bottleneck[1].bias.data.fill_(0.1)

        def _make_head() -> nn.Sequential:
            # Two-layer MLP head; every Linear layer gets N(0, 0.01) weights and zero bias.
            module = nn.Sequential(
                nn.Linear(bottleneck_dim, width),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Linear(width, num_classes)
            )
            for m in module:
                if isinstance(m, nn.Linear):
                    m.weight.data.normal_(0, 0.01)
                    m.bias.data.fill_(0.0)
            return module

        head = _make_head()      # classifier head used for final predictions
        adv_head = _make_head()  # adversarial classifier head (MDD loss only)
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck,
                                              head, adv_head, grl_layer, finetune)
class ImageRegressor(GeneralModule):
    r"""Regressor for MDD: one backbone, one bottleneck, and two regressor heads.
    The first head produces the final predictions; the adversarial head is only
    used when calculating :class:`MarginDisparityDiscrepancy`.

    Args:
        backbone (torch.nn.Module): any backbone that extracts features from data
        num_factors (int): number of factors
        bottleneck (nn.Module, optional): custom bottleneck; a conv-BN-ReLU block is built when None
        head (nn.Module, optional): custom main head; a default conv head is built when None
        adv_head (nn.Module, optional): custom adversarial head; a default conv head is built when None
        bottleneck_dim (int, optional): feature dimension of the bottleneck layer. Default: 1024
        width (int, optional): feature dimension of the regressor head. Default: 1024
        finetune (bool, optional): use a 10x smaller learning rate for the backbone. Default: True

    .. note::
        Remember to call :meth:`step` after :meth:`forward` **during training**:

        >>> # x is inputs, regressor is an ImageRegressor
        >>> outputs, outputs_adv = regressor(x)
        >>> regressor.step()
    """

    def __init__(self, backbone: nn.Module, num_factors: int, bottleneck=None, head=None, adv_head=None,
                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024, finetune=True):
        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,
                                                  auto_step=False)
        if bottleneck is None:
            bottleneck = nn.Sequential(
                nn.Conv2d(backbone.out_features, bottleneck_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(bottleneck_dim),
                nn.ReLU(),
            )

        def _build_head() -> nn.Sequential:
            # Default two-conv regression head ending in sigmoid-normalized factor outputs;
            # conv/linear layers get N(0, 0.01) weights and zero bias.
            module = nn.Sequential(
                nn.Conv2d(bottleneck_dim, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
                nn.Linear(width, num_factors),
                nn.Sigmoid()
            )
            for m in module:
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.normal_(m.weight, 0, 0.01)
                    nn.init.constant_(m.bias, 0)
            return module

        if head is None:
            head = _build_head()       # regressor head used for final predictions
        if adv_head is None:
            adv_head = _build_head()   # adversarial regressor head (MDD loss only)
        super(ImageRegressor, self).__init__(backbone, num_factors, bottleneck,
                                             head, adv_head, grl_layer, finetune)
        self.num_factors = num_factors
class UnknownClassBinaryCrossEntropy(nn.Module):
    r"""
    Binary cross entropy loss that builds a boundary for unknown samples, proposed by
    `Open Set Domain Adaptation by Backpropagation (ECCV 2018)
    <https://arxiv.org/abs/1804.10427>`_.

    Given a target-domain sample :math:`x_t` with classification outputs :math:`y`:

    .. math::
        L_{\text{adv}}(x_t) = -t \log(p(y=C+1|x_t)) - (1-t)\log(1-p(y=C+1|x_t))

    where t is a hyper-parameter and C is the number of known classes; the last
    output dimension is treated as the unknown class.

    Args:
        t (float): predefined hyper-parameter. Default: 0.5

    Inputs:
        - y (tensor): classification outputs (before softmax).

    Shape:
        - y: :math:`(minibatch, C+1)` where C is the number of known classes.
        - Outputs: scalar
    """

    def __init__(self, t: Optional[float] = 0.5):
        super(UnknownClassBinaryCrossEntropy, self).__init__()
        self.t = t

    def forward(self, y):
        # y: N x (C+1); the last column is the "unknown" class probability.
        probs = F.softmax(y, dim=1)
        p_unknown = probs[:, -1].contiguous().view(-1, 1)
        p_known = 1. - p_unknown
        # Constant target t for the unknown class, 1 - t for the known mass.
        target_unknown = torch.ones((y.size(0), 1)).to(y.device) * self.t
        target_known = 1. - target_unknown
        return - torch.mean(target_unknown * torch.log(p_unknown + 1e-6)) \
               - torch.mean(target_known * torch.log(p_known + 1e-6))
class ImageClassifier(ClassifierBase):
    """OSBP image classifier: a two-layer bottleneck on top of the backbone, plus a
    gradient-reverse layer that can be switched on per forward call so the same
    head serves both the classification and the adversarial objective."""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(bottleneck_dim, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU(),
            nn.Dropout()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)
        self.grl = GradientReverseLayer()

    def forward(self, x: torch.Tensor, grad_reverse: Optional[bool] = False):
        """Classify x; when ``grad_reverse`` is True the features pass through the
        GRL so the backbone receives reversed gradients."""
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        if grad_reverse:
            features = self.grl(features)
        outputs = self.head(features)
        if self.training:
            return outputs, features
        return outputs
class FastPseudoLabelGenerator2d(nn.Module):
    """Vectorized pseudo-label generator for 2-d heatmaps.

    For each keypoint channel, places a Gaussian (peak value 1) at the argmax of
    the predicted heatmap ("ground truth") and builds a "ground false" map as the
    clamped sum of the Gaussians of all *other* keypoints.

    Args:
        sigma (float): standard deviation of the generated Gaussians. Default: 2
    """

    def __init__(self, sigma=2):
        super().__init__()
        self.sigma = sigma

    def forward(self, heatmap: torch.Tensor):
        heatmap = heatmap.detach()
        h, w = heatmap.shape[-2:]
        # Argmax over the flattened spatial dims, then split into row/col. B, K
        flat_idx = heatmap.flatten(-2).argmax(dim=-1)
        rows = flat_idx.div(w, rounding_mode='floor')
        cols = flat_idx.remainder(w)
        # Offsets of every pixel from the predicted peak. B, K, H and B, K, W
        dr = torch.arange(h, device=heatmap.device) - rows.unsqueeze(-1)
        dc = torch.arange(w, device=heatmap.device) - cols.unsqueeze(-1)
        # Separable squared distance -> Gaussian. B, K, H, W
        dist_sq = dr.square().unsqueeze(-1) + dc.square().unsqueeze(-2)
        gaussian = dist_sq.div(-2 * self.sigma * self.sigma).exp()
        # Zero out the far tail so the "false" map only mixes nearby keypoints.
        ground_truth = F.threshold(gaussian, threshold=1e-2, value=0.)
        ground_false = (ground_truth.sum(dim=1, keepdim=True) - ground_truth).clamp(0., 1.)
        return ground_truth, ground_false
class PseudoLabelGenerator2d(nn.Module):
    """
    Generate ground truth heatmap and ground false heatmap from a prediction.

    Args:
        num_keypoints (int): number of keypoints
        height (int): height of the heatmap. Default: 64
        width (int): width of the heatmap. Default: 64
        sigma (int): sigma parameter when generating the heatmap. Default: 2

    Inputs:
        - y: predicted heatmap

    Outputs:
        - ground_truth: heatmap conforming to a Gaussian distribution
        - ground_false: ground false heatmap

    Shape:
        - y: :math:`(minibatch, K, H, W)` where K is the number of keypoints.
        - ground_truth / ground_false: :math:`(minibatch, K, H, W)`
    """

    def __init__(self, num_keypoints, height=64, width=64, sigma=2):
        super(PseudoLabelGenerator2d, self).__init__()
        self.height = height
        self.width = width
        self.sigma = sigma

        # heatmaps[x][y] is the full Gaussian heatmap whose peak (value 1) is at (x, y),
        # precomputed once for every possible peak position.
        heatmaps = np.zeros((width, height, height, width), dtype=np.float32)
        tmp_size = sigma * 3
        for mu_x in range(width):
            for mu_y in range(height):
                # Corners of the Gaussian patch; any part may fall out of bounds.
                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]

                # Generate the (unnormalized) Gaussian: the center value equals 1.
                size = 2 * tmp_size + 1
                x = np.arange(0, size, 1, np.float32)
                y = x[:, np.newaxis]
                x0 = y0 = size // 2
                g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))

                # Usable Gaussian range and corresponding image range.
                g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
                g_y = max(0, -ul[1]), min(br[1], height) - ul[1]
                img_x = max(0, ul[0]), min(br[0], width)
                img_y = max(0, ul[1]), min(br[1], height)

                heatmaps[mu_x][mu_y][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                    g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
        self.heatmaps = heatmaps
        # false_matrix mixes each keypoint with every *other* keypoint (zero diagonal).
        self.false_matrix = 1. - np.eye(num_keypoints, dtype=np.float32)

    def forward(self, y):
        B, K, H, W = y.shape
        y = y.detach()
        preds, max_vals = get_max_preds(y.cpu().numpy())  # B x K x (x, y)
        # FIX: the deprecated ``np.int`` alias was removed in NumPy 1.24;
        # use an explicit integer dtype for the lookup indices.
        preds = preds.reshape(-1, 2).astype(np.int64)
        # Look up the precomputed Gaussian for each predicted peak position.
        ground_truth = self.heatmaps[preds[:, 0], preds[:, 1], :, :].copy().reshape(B, K, H, W).copy()

        # Ground false: for each keypoint, the clamped sum of all other keypoints' Gaussians.
        ground_false = ground_truth.reshape(B, K, -1).transpose((0, 2, 1))
        ground_false = ground_false.dot(self.false_matrix).clip(max=1., min=0.). \
            transpose((0, 2, 1)).reshape(B, K, H, W).copy()
        return torch.from_numpy(ground_truth).to(y.device), torch.from_numpy(ground_false).to(y.device)


class RegressionDisparity(nn.Module):
    """
    Regression disparity proposed by `Regressive Domain Adaptation for Unsupervised
    Keypoint Detection (CVPR 2021) <https://arxiv.org/abs/2103.06175>`_.

    Args:
        pseudo_label_generator (PseudoLabelGenerator2d): generates the ground truth and
            ground false heatmaps from a prediction.
        criterion (torch.nn.Module): loss function measuring the distance between two predictions.

    Inputs:
        - y: output by the main head
        - y_adv: output by the adversarial head
        - weight (optional): instance weights
        - mode (str): ``min`` to minimize the disparity, ``max`` to maximize it. Default: ``min``

    Shape:
        - y, y_adv: :math:`(minibatch, K, H, W)` where K is the number of keypoints.
        - weight: :math:`(minibatch, K)`.
        - Output: depends on the ``criterion``.
    """

    def __init__(self, pseudo_label_generator: PseudoLabelGenerator2d, criterion: nn.Module):
        super(RegressionDisparity, self).__init__()
        self.criterion = criterion
        self.pseudo_label_generator = pseudo_label_generator

    def forward(self, y, y_adv, weight=None, mode='min'):
        assert mode in ['min', 'max']
        ground_truth, ground_false = self.pseudo_label_generator(y.detach())
        # Cached so callers can visualize/debug the generated targets.
        self.ground_truth = ground_truth
        self.ground_false = ground_false
        if mode == 'min':
            return self.criterion(y_adv, ground_truth, weight)
        else:
            return self.criterion(y_adv, ground_false, weight)
class PoseResNet2d(nn.Module):
    """
    Pose ResNet for RegDA: one backbone, one upsampling module, and two regression
    heads. The main head produces the final heatmaps; the adversarial head sits
    behind a warm-start gradient layer.

    Args:
        backbone (torch.nn.Module): backbone extracting 2-d features from data
        upsampling (torch.nn.Module): layer upsampling the feature map to heatmap size
        feature_dim (int): dimension of the features produced by ``upsampling``
        num_keypoints (int): number of keypoints
        gl (WarmStartGradientLayer): gradient layer; a default one is created when None
        finetune (bool, optional): use a 10x smaller learning rate for the backbone. Default: True
        num_head_layers (int): number of head layers. Default: 2

    Shape:
        - x: :math:`(minibatch, *)`, same as the backbone's input.
        - outputs, outputs_adv: :math:`(minibatch, K, H, W)` where K is the number of keypoints.

    .. note::
        Remember to call :meth:`step` after :meth:`forward` **during training**:

        >>> # x is inputs, model is a PoseResNet
        >>> outputs, outputs_adv = model(x)
        >>> model.step()
    """

    def __init__(self, backbone, upsampling, feature_dim, num_keypoints,
                 gl: Optional[WarmStartGradientLayer] = None, finetune: Optional[bool] = True,
                 num_head_layers=2):
        super(PoseResNet2d, self).__init__()
        self.backbone = backbone
        self.upsampling = upsampling
        self.head = self._make_head(num_head_layers, feature_dim, num_keypoints)
        self.head_adv = self._make_head(num_head_layers, feature_dim, num_keypoints)
        self.finetune = finetune
        if gl is None:
            gl = WarmStartGradientLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False)
        self.gl_layer = gl

    @staticmethod
    def _make_head(num_layers, channel_dim, num_keypoints):
        # (num_layers - 1) conv-BN-ReLU blocks followed by a 1x1 prediction conv.
        blocks = []
        for _ in range(num_layers - 1):
            blocks += [
                nn.Conv2d(channel_dim, channel_dim, 3, 1, 1),
                nn.BatchNorm2d(channel_dim),
                nn.ReLU(),
            ]
        blocks.append(
            nn.Conv2d(
                in_channels=channel_dim,
                out_channels=num_keypoints,
                kernel_size=1,
                stride=1,
                padding=0
            )
        )
        head = nn.Sequential(*blocks)
        for m in head.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)
        return head

    def forward(self, x):
        f = self.upsampling(self.backbone(x))
        y = self.head(f)
        y_adv = self.head_adv(self.gl_layer(f))
        if self.training:
            return y, y_adv
        return y

    def get_parameters(self, lr=1.):
        backbone_lr = 0.1 * lr if self.finetune else lr
        return [
            {'params': self.backbone.parameters(), 'lr': backbone_lr},
            {'params': self.upsampling.parameters(), 'lr': lr},
            {'params': self.head.parameters(), 'lr': lr},
            {'params': self.head_adv.parameters(), 'lr': lr},
        ]

    def step(self):
        r"""Call once per training iteration; increases :math:`\lambda` in the GL layer."""
        self.gl_layer.step()
'lr': lr}, {'params': self.head_adv.parameters(), 'lr': lr}, ] def step(self): """Call step() each iteration during training. Will increase :math:`\lambda` in GL layer. """ self.gl_layer.step() ================================================ FILE: tllib/alignment/rsd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch.nn as nn import torch class RepresentationSubspaceDistance(nn.Module): """ `Representation Subspace Distance (ICML 2021) `_ Args: trade_off (float): The trade-off value between Representation Subspace Distance and Base Mismatch Penalization. Default: 0.1 Inputs: - f_s (tensor): feature representations on source domain, :math:`f^s` - f_t (tensor): feature representations on target domain, :math:`f^t` """ def __init__(self, trade_off=0.1): super(RepresentationSubspaceDistance, self).__init__() self.trade_off = trade_off def forward(self, f_s, f_t): U_s, _, _ = torch.svd(f_s.t()) U_t, _, _ = torch.svd(f_t.t()) P_s, cosine, P_t = torch.svd(torch.mm(U_s.t(), U_t)) sine = torch.sqrt(1 - torch.pow(cosine, 2)) rsd = torch.norm(sine, 1) # Representation Subspace Distance bmp = torch.norm(torch.abs(P_s) - torch.abs(P_t), 2) # Base Mismatch Penalization return rsd + self.trade_off * bmp ================================================ FILE: tllib/modules/__init__.py ================================================ from .classifier import * from .regressor import * from .grl import * from .domain_discriminator import * from .kernels import * from .entropy import * __all__ = ['classifier', 'regressor', 'grl', 'kernels', 'domain_discriminator', 'entropy'] ================================================ FILE: tllib/modules/classifier.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Tuple, Optional, List, Dict import torch.nn as nn import torch __all__ = ['Classifier'] class 
class Classifier(nn.Module):
    """A generic classifier class for domain adaptation.

    Args:
        backbone (torch.nn.Module): any backbone that extracts 2-d features from data
        num_classes (int): number of classes
        bottleneck (torch.nn.Module, optional): bottleneck layer; no bottleneck by default
        bottleneck_dim (int, optional): feature dimension of the bottleneck layer. Default: -1
        head (torch.nn.Module, optional): classifier head; :class:`torch.nn.Linear` by default
        finetune (bool): finetune the backbone (10x smaller learning rate) or train from
            scratch. Default: True
        pool_layer (torch.nn.Module, optional): pooling applied to the backbone output;
            global average pooling + flatten by default.

    .. note::
        Different domain adaptation algorithms use different classifiers; this is a
        suggested generic one, not the core of any algorithm. The head/bottleneck
        learning rate is 10x the backbone's by default — override
        :meth:`get_parameters` for other optimization strategies.

    Inputs:
        - x (tensor): input data fed to ``backbone``

    Outputs:
        - predictions: classifier's predictions
        - features: features after the ``bottleneck`` layer and before the ``head`` layer

    Shape:
        - Inputs: (minibatch, *)
        - predictions: (minibatch, `num_classes`)
        - features: (minibatch, `features_dim`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, head: Optional[nn.Module] = None, finetune=True,
                 pool_layer=None):
        super(Classifier, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        self.pool_layer = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten()
        ) if pool_layer is None else pool_layer
        if bottleneck is None:
            # No bottleneck: features are the raw backbone output.
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        else:
            assert bottleneck_dim > 0
            self.bottleneck = bottleneck
            self._features_dim = bottleneck_dim
        self.head = nn.Linear(self._features_dim, num_classes) if head is None else head
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """"""
        features = self.bottleneck(self.pool_layer(self.backbone(x)))
        predictions = self.head(features)
        if self.training:
            return predictions, features
        return predictions

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
        such as the relative learning rate of each layer.
        """
        backbone_lr = 0.1 * base_lr if self.finetune else 1.0 * base_lr
        return [
            {"params": self.backbone.parameters(), "lr": backbone_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
        ]


class ImageClassifier(Classifier):
    pass
class DomainDiscriminator(nn.Sequential):
    r"""Domain discriminator model from
    `Domain-Adversarial Training of Neural Networks (ICML 2015)
    <https://arxiv.org/abs/1505.07818>`_

    Distinguishes whether input features come from the source domain (label 1)
    or the target domain (label 0).

    Args:
        in_feature (int): dimension of the input feature
        hidden_size (int): dimension of the hidden features
        batch_norm (bool): whether to use :class:`~torch.nn.BatchNorm1d`;
            :class:`~torch.nn.Dropout` is used when False. Default: True
        sigmoid (bool): when True, end with a single sigmoid unit; otherwise
            end with two raw logits. Default: True

    Shape:
        - Inputs: (minibatch, `in_feature`)
        - Outputs: :math:`(minibatch, 1)` (sigmoid) or :math:`(minibatch, 2)`
    """

    def __init__(self, in_feature: int, hidden_size: int, batch_norm=True, sigmoid=True):
        if sigmoid:
            final_layer = nn.Sequential(
                nn.Linear(hidden_size, 1),
                nn.Sigmoid()
            )
        else:
            final_layer = nn.Linear(hidden_size, 2)
        if batch_norm:
            trunk = [
                nn.Linear(in_feature, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
            ]
        else:
            trunk = [
                nn.Linear(in_feature, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
            ]
        super(DomainDiscriminator, self).__init__(*trunk, final_layer)

    def get_parameters(self) -> List[Dict]:
        return [{"params": self.parameters(), "lr": 1.}]
def entropy(predictions: torch.Tensor, reduction='none') -> torch.Tensor:
    r"""Entropy of prediction. The definition is:

    .. math::
        entropy(p) = - \sum_{c=1}^C p_c \log p_c

    where C is the number of classes.

    Args:
        predictions (tensor): classifier predictions, expected to contain
            *normalized* class probabilities (e.g. softmax output).
        reduction (str, optional): ``'none'`` | ``'mean'``. ``'none'``: no
            reduction is applied; ``'mean'``: the output is averaged over the
            minibatch. Default: ``'none'``
            (the previous docstring claimed ``'mean'`` was the default, which
            contradicted the signature).

    Shape:
        - predictions: :math:`(minibatch, C)` where C is the number of classes.
        - Output: :math:`(minibatch, )` by default; scalar if :attr:`reduction`
          is ``'mean'``.
    """
    epsilon = 1e-5  # guards log(0) for zero-probability classes
    H = -predictions * torch.log(predictions + epsilon)
    H = H.sum(dim=1)
    if reduction == 'mean':
        return H.mean()
    else:
        return H
Default: 0.0 - **hi** (float, optional): Final value of :math:`\lambda`. Default: 1.0 - **max_iters** (int, optional): :math:`N`. Default: 1000 - **auto_step** (bool, optional): If True, increase :math:`i` each time `forward` is called. Otherwise use function `step` to increase :math:`i`. Default: False """ def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False): super(WarmStartGradientLayer, self).__init__() self.alpha = alpha self.lo = lo self.hi = hi self.iter_num = 0 self.max_iters = max_iters self.auto_step = auto_step def forward(self, input: torch.Tensor) -> torch.Tensor: """""" coeff = np.float( 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) - (self.hi - self.lo) + self.lo ) if self.auto_step: self.step() return GradientFunction.apply(input, coeff) def step(self): """Increase iteration number :math:`i` by 1""" self.iter_num += 1 ================================================ FILE: tllib/modules/grl.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional, Any, Tuple import numpy as np import torch.nn as nn from torch.autograd import Function import torch class GradientReverseFunction(Function): @staticmethod def forward(ctx: Any, input: torch.Tensor, coeff: Optional[float] = 1.) -> torch.Tensor: ctx.coeff = coeff output = input * 1.0 return output @staticmethod def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, Any]: return grad_output.neg() * ctx.coeff, None class GradientReverseLayer(nn.Module): def __init__(self): super(GradientReverseLayer, self).__init__() def forward(self, *input): return GradientReverseFunction.apply(*input) class WarmStartGradientReverseLayer(nn.Module): """Gradient Reverse Layer :math:`\mathcal{R}(x)` with warm start The forward and backward behaviours are: .. 
math:: \mathcal{R}(x) = x, \dfrac{ d\mathcal{R}} {dx} = - \lambda I. :math:`\lambda` is initiated at :math:`lo` and is gradually changed to :math:`hi` using the following schedule: .. math:: \lambda = \dfrac{2(hi-lo)}{1+\exp(- α \dfrac{i}{N})} - (hi-lo) + lo where :math:`i` is the iteration step. Args: alpha (float, optional): :math:`α`. Default: 1.0 lo (float, optional): Initial value of :math:`\lambda`. Default: 0.0 hi (float, optional): Final value of :math:`\lambda`. Default: 1.0 max_iters (int, optional): :math:`N`. Default: 1000 auto_step (bool, optional): If True, increase :math:`i` each time `forward` is called. Otherwise use function `step` to increase :math:`i`. Default: False """ def __init__(self, alpha: Optional[float] = 1.0, lo: Optional[float] = 0.0, hi: Optional[float] = 1., max_iters: Optional[int] = 1000., auto_step: Optional[bool] = False): super(WarmStartGradientReverseLayer, self).__init__() self.alpha = alpha self.lo = lo self.hi = hi self.iter_num = 0 self.max_iters = max_iters self.auto_step = auto_step def forward(self, input: torch.Tensor) -> torch.Tensor: """""" coeff = np.float( 2.0 * (self.hi - self.lo) / (1.0 + np.exp(-self.alpha * self.iter_num / self.max_iters)) - (self.hi - self.lo) + self.lo ) if self.auto_step: self.step() return GradientReverseFunction.apply(input, coeff) def step(self): """Increase iteration number :math:`i` by 1""" self.iter_num += 1 ================================================ FILE: tllib/modules/kernels.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional import torch import torch.nn as nn __all__ = ['GaussianKernel'] class GaussianKernel(nn.Module): r"""Gaussian Kernel Matrix Gaussian Kernel k is defined by .. math:: k(x_1, x_2) = \exp \left( - \dfrac{\| x_1 - x_2 \|^2}{2\sigma^2} \right) where :math:`x_1, x_2 \in R^d` are 1-d tensors. 
Gaussian Kernel Matrix K is defined on input group :math:`X=(x_1, x_2, ..., x_m),` .. math:: K(X)_{i,j} = k(x_i, x_j) Also by default, during training this layer keeps running estimates of the mean of L2 distances, which are then used to set hyperparameter :math:`\sigma`. Mathematically, the estimation is :math:`\sigma^2 = \dfrac{\alpha}{n^2}\sum_{i,j} \| x_i - x_j \|^2`. If :attr:`track_running_stats` is set to ``False``, this layer then does not keep running estimates, and use a fixed :math:`\sigma` instead. Args: sigma (float, optional): bandwidth :math:`\sigma`. Default: None track_running_stats (bool, optional): If ``True``, this module tracks the running mean of :math:`\sigma^2`. Otherwise, it won't track such statistics and always uses fix :math:`\sigma^2`. Default: ``True`` alpha (float, optional): :math:`\alpha` which decides the magnitude of :math:`\sigma^2` when track_running_stats is set to ``True`` Inputs: - X (tensor): input group :math:`X` Shape: - Inputs: :math:`(minibatch, F)` where F means the dimension of input features. 
- Outputs: :math:`(minibatch, minibatch)` """ def __init__(self, sigma: Optional[float] = None, track_running_stats: Optional[bool] = True, alpha: Optional[float] = 1.): super(GaussianKernel, self).__init__() assert track_running_stats or sigma is not None self.sigma_square = torch.tensor(sigma * sigma) if sigma is not None else None self.track_running_stats = track_running_stats self.alpha = alpha def forward(self, X: torch.Tensor) -> torch.Tensor: l2_distance_square = ((X.unsqueeze(0) - X.unsqueeze(1)) ** 2).sum(2) if self.track_running_stats: self.sigma_square = self.alpha * torch.mean(l2_distance_square.detach()) return torch.exp(-l2_distance_square / (2 * self.sigma_square)) ================================================ FILE: tllib/modules/loss.py ================================================ import torch.nn as nn import torch import torch.nn.functional as F # version 1: use torch.autograd class LabelSmoothSoftmaxCEV1(nn.Module): ''' Adapted from https://github.com/CoinCheung/pytorch-loss ''' def __init__(self, lb_smooth=0.1, reduction='mean', ignore_index=-1): super(LabelSmoothSoftmaxCEV1, self).__init__() self.lb_smooth = lb_smooth self.reduction = reduction self.lb_ignore = ignore_index self.log_softmax = nn.LogSoftmax(dim=1) def forward(self, input, target): ''' Same usage method as nn.CrossEntropyLoss: >>> criteria = LabelSmoothSoftmaxCEV1() >>> logits = torch.randn(8, 19, 384, 384) # nchw, float/half >>> lbs = torch.randint(0, 19, (8, 384, 384)) # nhw, int64_t >>> loss = criteria(logits, lbs) ''' # overcome ignored label logits = input.float() # use fp32 to avoid nan with torch.no_grad(): num_classes = logits.size(1) label = target.clone().detach() ignore = label.eq(self.lb_ignore) n_valid = ignore.eq(0).sum() label[ignore] = 0 lb_pos, lb_neg = 1. 
- self.lb_smooth, self.lb_smooth / num_classes lb_one_hot = torch.empty_like(logits).fill_( lb_neg).scatter_(1, label.unsqueeze(1), lb_pos).detach() logs = self.log_softmax(logits) loss = -torch.sum(logs * lb_one_hot, dim=1) loss[ignore] = 0 if self.reduction == 'mean': loss = loss.sum() / n_valid if self.reduction == 'sum': loss = loss.sum() return loss class KnowledgeDistillationLoss(nn.Module): """Knowledge Distillation Loss. Args: T (double): Temperature. Default: 1. reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'batchmean'`` Inputs: - y_student (tensor): logits output of the student - y_teacher (tensor): logits output of the teacher Shape: - y_student: (minibatch, `num_classes`) - y_teacher: (minibatch, `num_classes`) """ def __init__(self, T=1., reduction='batchmean'): super(KnowledgeDistillationLoss, self).__init__() self.T = T self.kl = nn.KLDivLoss(reduction=reduction) def forward(self, y_student, y_teacher): """""" return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1)) ================================================ FILE: tllib/modules/regressor.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Tuple, Optional, List, Dict import torch.nn as nn import torch __all__ = ['Regressor'] class Regressor(nn.Module): """A generic Regressor class for domain adaptation. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data num_factors (int): Number of factors bottleneck (torch.nn.Module, optional): Any bottleneck layer. Use no bottleneck by default bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. 
Default: -1 head (torch.nn.Module, optional): Any classifier head. Use `nn.Linear` by default finetune (bool): Whether finetune the classifier or train from scratch. Default: True .. note:: The learning rate of this regressor is set 10 times to that of the feature extractor for better accuracy by default. If you have other optimization strategies, please over-ride :meth:`~Regressor.get_parameters`. Inputs: - x (tensor): input data fed to `backbone` Outputs: - predictions: regressor's predictions - features: features after `bottleneck` layer and before `head` layer Shape: - Inputs: (minibatch, *) where * means, any number of additional dimensions - predictions: (minibatch, `num_factors`) - features: (minibatch, `features_dim`) """ def __init__(self, backbone: nn.Module, num_factors: int, bottleneck: Optional[nn.Module] = None, bottleneck_dim=-1, head: Optional[nn.Module] = None, finetune=True): super(Regressor, self).__init__() self.backbone = backbone self.num_factors = num_factors if bottleneck is None: feature_dim = backbone.out_features self.bottleneck = nn.Sequential( nn.Conv2d(feature_dim, feature_dim, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(feature_dim, feature_dim), nn.ReLU(), nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten() ) self._features_dim = feature_dim else: self.bottleneck = bottleneck assert bottleneck_dim > 0 self._features_dim = bottleneck_dim if head is None: self.head = nn.Sequential( nn.Linear(self._features_dim, num_factors), nn.Sigmoid() ) else: self.head = head self.finetune = finetune @property def features_dim(self) -> int: """The dimension of features before the final `head` layer""" return self._features_dim def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """""" f = self.backbone(x) f = self.bottleneck(f) predictions = self.head(f) if self.training: return predictions, f else: return predictions def get_parameters(self, base_lr=1.0) -> List[Dict]: """A parameter list which decides optimization 
hyper-parameters, such as the relative learning rate of each layer """ params = [ {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, {"params": self.head.parameters(), "lr": 1.0 * base_lr}, ] return params ================================================ FILE: tllib/normalization/__init__.py ================================================ ================================================ FILE: tllib/normalization/afn.py ================================================ """ Modified from https://github.com/jihanyang/AFN @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from typing import Optional, List, Dict import torch import torch.nn as nn import math from tllib.modules.classifier import Classifier as ClassfierBase class AdaptiveFeatureNorm(nn.Module): r""" The `Stepwise Adaptive Feature Norm loss (ICCV 2019) `_ Instead of using restrictive scalar R to match the corresponding feature norm, Stepwise Adaptive Feature Norm is used in order to learn task-specific features with large norms in a progressive manner. We denote parameters of backbone :math:`G` as :math:`\theta_g`, parameters of bottleneck :math:`F_f` as :math:`\theta_f` , parameters of classifier head :math:`F_y` as :math:`\theta_y`, and features extracted from sample :math:`x_i` as :math:`h(x_i;\theta)`. Full loss is calculated as follows .. math:: L(\theta_g,\theta_f,\theta_y)=\frac{1}{n_s}\sum_{(x_i,y_i)\in D_s}L_y(x_i,y_i)+\frac{\lambda}{n_s+n_t} \sum_{x_i\in D_s\cup D_t}L_d(h(x_i;\theta_0)+\Delta_r,h(x_i;\theta))\\ where :math:`L_y` denotes classification loss, :math:`L_d` denotes norm loss, :math:`\theta_0` and :math:`\theta` represent the updated and updating model parameters in the last and current iterations respectively. Args: delta (float): positive residual scalar to control the feature norm enlargement. Inputs: - f (tensor): feature representations on source or target domain. 
Shape: - f: :math:`(N, F)` where F means the dimension of input features. - Outputs: scalar. Examples:: >>> adaptive_feature_norm = AdaptiveFeatureNorm(delta=1) >>> f_s = torch.randn(32, 1000) >>> f_t = torch.randn(32, 1000) >>> norm_loss = adaptive_feature_norm(f_s) + adaptive_feature_norm(f_t) """ def __init__(self, delta): super(AdaptiveFeatureNorm, self).__init__() self.delta = delta def forward(self, f: torch.Tensor) -> torch.Tensor: radius = f.norm(p=2, dim=1).detach() assert radius.requires_grad == False radius = radius + self.delta loss = ((f.norm(p=2, dim=1) - radius) ** 2).mean() return loss class Block(nn.Module): r""" Basic building block for Image Classifier with structure: FC-BN-ReLU-Dropout. We use :math:`L_2` preserved dropout layers. Given mask probability :math:`p`, input :math:`x_k`, generated mask :math:`a_k`, vanilla dropout layers calculate .. math:: \hat{x}_k = a_k\frac{1}{1-p}x_k\\ While in :math:`L_2` preserved dropout layers .. math:: \hat{x}_k = a_k\frac{1}{\sqrt{1-p}}x_k\\ Args: in_features (int): Dimension of input features bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000 dropout_p (float, optional): dropout probability. Default: 0.5 """ def __init__(self, in_features: int, bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5): super(Block, self).__init__() self.fc = nn.Linear(in_features, bottleneck_dim) self.bn = nn.BatchNorm1d(bottleneck_dim, affine=True) self.relu = nn.ReLU(inplace=True) self.dropout = nn.Dropout(dropout_p) self.dropout_p = dropout_p def forward(self, x: torch.Tensor) -> torch.Tensor: f = self.fc(x) f = self.bn(f) f = self.relu(f) f = self.dropout(f) if self.training: f.mul_(math.sqrt(1 - self.dropout_p)) return f class ImageClassifier(ClassfierBase): r""" ImageClassifier for AFN. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data num_classes (int): Number of classes num_blocks (int, optional): Number of basic blocks. 
Default: 1 bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. Default: 1000 dropout_p (float, optional): dropout probability. Default: 0.5 """ def __init__(self, backbone: nn.Module, num_classes: int, num_blocks: Optional[int] = 1, bottleneck_dim: Optional[int] = 1000, dropout_p: Optional[float] = 0.5, **kwargs): assert num_blocks >= 1 layers = [nn.Sequential( Block(backbone.out_features, bottleneck_dim, dropout_p) )] for _ in range(num_blocks - 1): layers.append(Block(bottleneck_dim, bottleneck_dim, dropout_p)) bottleneck = nn.Sequential(*layers) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) # init parameters for bottleneck and head for m in self.bottleneck.modules(): if isinstance(m, nn.BatchNorm1d): m.weight.data.normal_(1.0, 0.01) m.bias.data.fill_(0) if isinstance(m, nn.Linear): m.weight.data.normal_(0.0, 0.01) m.bias.data.normal_(0.0, 0.01) for m in self.head.modules(): if isinstance(m, nn.Linear): m.weight.data.normal_(0.0, 0.01) m.bias.data.normal_(0.0, 0.01) def get_parameters(self, base_lr=1.0) -> List[Dict]: params = [ {"params": self.backbone.parameters()}, {"params": self.bottleneck.parameters(), "momentum": 0.9}, {"params": self.head.parameters(), "momentum": 0.9}, ] return params ================================================ FILE: tllib/normalization/ibn.py ================================================ """ Modified from https://github.com/XingangPan/IBN-Net @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import math import torch import torch.nn as nn __all__ = ['resnet18_ibn_a', 'resnet18_ibn_b', 'resnet34_ibn_a', 'resnet34_ibn_b', 'resnet50_ibn_a', 'resnet50_ibn_b', 'resnet101_ibn_a', 'resnet101_ibn_b'] model_urls = { 'resnet18_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth', 'resnet34_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth', 'resnet50_ibn_a': 
'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth', 'resnet101_ibn_a': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth', 'resnet18_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_b-bc2f3c11.pth', 'resnet34_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_b-04134c37.pth', 'resnet50_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_b-9ca61e85.pth', 'resnet101_ibn_b': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_b-c55f6dba.pth', } class InstanceBatchNorm2d(nn.Module): r"""Instance-Batch Normalization layer from `Two at Once: Enhancing Learning and Generalization Capacities via IBN-Net (ECCV 2018) `_. Given input feature map :math:`f\_input` of dimension :math:`(C,H,W)`, we first split :math:`f\_input` into two parts along `channel` dimension. They are denoted as :math:`f_1` of dimension :math:`(C_1,H,W)` and :math:`f_2` of dimension :math:`(C_2,H,W)`, where :math:`C_1+C_2=C`. Then we pass :math:`f_1` and :math:`f_2` through IN and BN layer, respectively, to get :math:`IN(f_1)` and :math:`BN(f_2)`. Last, we concat them along `channel` dimension to create :math:`f\_output=concat(IN(f_1), BN(f_2))`. 
Args: planes (int): Number of channels for the input tensor ratio (float): Ratio of instance normalization in the IBN layer """ def __init__(self, planes, ratio=0.5): super(InstanceBatchNorm2d, self).__init__() self.half = int(planes * ratio) self.IN = nn.InstanceNorm2d(self.half, affine=True) self.BN = nn.BatchNorm2d(planes - self.half) def forward(self, x): split = torch.split(x, self.half, 1) out1 = self.IN(split[0].contiguous()) out2 = self.BN(split[1].contiguous()) out = torch.cat((out1, out2), 1) return out class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) if ibn == 'a': self.bn1 = InstanceBatchNorm2d(planes) else: self.bn1 = nn.BatchNorm2d(planes) self.relu = nn.ReLU(inplace=True) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.IN = nn.InstanceNorm2d(planes, affine=True) if ibn == 'b' else None self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if self.downsample is not None: residual = self.downsample(x) out += residual if self.IN is not None: out = self.IN(out) out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, ibn=None, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) if ibn == 'a': self.bn1 = InstanceBatchNorm2d(planes) else: self.bn1 = nn.BatchNorm2d(planes) self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion) 
self.IN = nn.InstanceNorm2d(planes * 4, affine=True) if ibn == 'b' else None self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out += residual if self.IN is not None: out = self.IN(out) out = self.relu(out) return out class IBNNet(nn.Module): r""" IBNNet without fully connected layer """ def __init__(self, block, layers, ibn_cfg=('a', 'a', 'a', None)): self.inplanes = 64 super(IBNNet, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) if ibn_cfg[0] == 'b': self.bn1 = nn.InstanceNorm2d(64, affine=True) else: self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0], ibn=ibn_cfg[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, ibn=ibn_cfg[1]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, ibn=ibn_cfg[2]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, ibn=ibn_cfg[3]) self._out_features = 512 * block.expansion for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels m.weight.data.normal_(0, math.sqrt(2. 
/ n)) elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def _make_layer(self, block, planes, blocks, stride=1, ibn=None): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion), ) layers = [] layers.append(block(self.inplanes, planes, None if ibn == 'b' else ibn, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): layers.append(block(self.inplanes, planes, None if (ibn == 'b' and i < blocks - 1) else ibn)) return nn.Sequential(*layers) def forward(self, x): """""" x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x @property def out_features(self) -> int: """The dimension of output features""" return self._out_features def resnet18_ibn_a(pretrained=False): """Constructs a ResNet-18-IBN-a model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=BasicBlock, layers=[2, 2, 2, 2], ibn_cfg=('a', 'a', 'a', None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_a']), strict=False) return model def resnet34_ibn_a(pretrained=False): """Constructs a ResNet-34-IBN-a model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=BasicBlock, layers=[3, 4, 6, 3], ibn_cfg=('a', 'a', 'a', None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_a']), strict=False) return model def resnet50_ibn_a(pretrained=False): """Constructs a ResNet-50-IBN-a model. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=Bottleneck, layers=[3, 4, 6, 3], ibn_cfg=('a', 'a', 'a', None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_a']), strict=False) return model def resnet101_ibn_a(pretrained=False): """Constructs a ResNet-101-IBN-a model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=Bottleneck, layers=[3, 4, 23, 3], ibn_cfg=('a', 'a', 'a', None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_a']), strict=False) return model def resnet18_ibn_b(pretrained=False): """Constructs a ResNet-18-IBN-b model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=BasicBlock, layers=[2, 2, 2, 2], ibn_cfg=('b', 'b', None, None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet18_ibn_b']), strict=False) return model def resnet34_ibn_b(pretrained=False): """Constructs a ResNet-34-IBN-b model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=BasicBlock, layers=[3, 4, 6, 3], ibn_cfg=('b', 'b', None, None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet34_ibn_b']), strict=False) return model def resnet50_ibn_b(pretrained=False): """Constructs a ResNet-50-IBN-b model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=Bottleneck, layers=[3, 4, 6, 3], ibn_cfg=('b', 'b', None, None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet50_ibn_b']), strict=False) return model def resnet101_ibn_b(pretrained=False): """Constructs a ResNet-101-IBN-b model. 
Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = IBNNet(block=Bottleneck, layers=[3, 4, 23, 3], ibn_cfg=('b', 'b', None, None)) if pretrained: model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls['resnet101_ibn_b']), strict=False) return model ================================================ FILE: tllib/normalization/mixstyle/__init__.py ================================================ """ Modified from https://github.com/KaiyangZhou/mixstyle-release @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import random import torch import torch.nn as nn class MixStyle(nn.Module): r"""MixStyle module from `DOMAIN GENERALIZATION WITH MIXSTYLE (ICLR 2021) `_. Given input :math:`x`, we first compute mean :math:`\mu(x)` and standard deviation :math:`\sigma(x)` across spatial dimension. Then we permute :math:`x` and get :math:`\tilde{x}`, corresponding mean :math:`\mu(\tilde{x})` and standard deviation :math:`\sigma(\tilde{x})`. `MixUp` is performed using mean and standard deviation .. math:: \gamma_{mix} = \lambda\sigma(x) + (1-\lambda)\sigma(\tilde{x}) .. math:: \beta_{mix} = \lambda\mu(x) + (1-\lambda)\mu(\tilde{x}) where :math:`\lambda` is instance-wise weight sampled from `Beta distribution`. MixStyle is then .. math:: MixStyle(x) = \gamma_{mix}\frac{x-\mu(x)}{\sigma(x)} + \beta_{mix} Args: p (float): probability of using MixStyle. alpha (float): parameter of the `Beta distribution`. eps (float): scaling parameter to avoid numerical issues. 
""" def __init__(self, p=0.5, alpha=0.1, eps=1e-6): super().__init__() self.p = p self.beta = torch.distributions.Beta(alpha, alpha) self.eps = eps self.alpha = alpha def forward(self, x): if not self.training: return x if random.random() > self.p: return x batch_size = x.size(0) mu = x.mean(dim=[2, 3], keepdim=True) var = x.var(dim=[2, 3], keepdim=True) sigma = (var + self.eps).sqrt() mu, sigma = mu.detach(), sigma.detach() x_normed = (x - mu) / sigma interpolation = self.beta.sample((batch_size, 1, 1, 1)) interpolation = interpolation.to(x.device) # split into two halves and swap the order perm = torch.arange(batch_size - 1, -1, -1) # inverse index perm_b, perm_a = perm.chunk(2) perm_b = perm_b[torch.randperm(batch_size // 2)] perm_a = perm_a[torch.randperm(batch_size // 2)] perm = torch.cat([perm_b, perm_a], 0) mu_perm, sigma_perm = mu[perm], sigma[perm] mu_mix = mu * interpolation + mu_perm * (1 - interpolation) sigma_mix = sigma * interpolation + sigma_perm * (1 - interpolation) return x_normed * sigma_mix + mu_mix ================================================ FILE: tllib/normalization/mixstyle/resnet.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from . import MixStyle from tllib.vision.models.reid.resnet import ReidResNet from tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck __all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101'] def _resnet_with_mix_style(arch, block, layers, pretrained, progress, mix_layers=None, mix_p=0.5, mix_alpha=0.1, resnet_class=ResNet, **kwargs): """Construct `ResNet` with MixStyle modules. Given any resnet architecture **resnet_class** that contains conv1, bn1, relu, maxpool, layer1-4, this function define a new class that inherits from **resnet_class** and inserts MixStyle module during forward pass. 
Although MixStyle Module can be inserted anywhere, original paper finds it better to place MixStyle after layer1-3. Our implementation follows this idea, but you are free to modify this function to try other possibilities. Args: arch (str): resnet architecture (resnet50 for example) block (class): class of resnet block layers (list): depth list of each block pretrained (bool): if True, load imagenet pre-trained model parameters progress (bool): whether or not to display a progress bar to stderr mix_layers (list): layers to insert MixStyle module after mix_p (float): probability to activate MixStyle during forward pass mix_alpha (float): parameter alpha for beta distribution resnet_class (class): base resnet class to inherit from """ if mix_layers is None: mix_layers = [] available_resnet_class = [ResNet, ReidResNet] assert resnet_class in available_resnet_class class ResNetWithMixStyleModule(resnet_class): def __init__(self, mix_layers, mix_p=0.5, mix_alpha=0.1, *args, **kwargs): super(ResNetWithMixStyleModule, self).__init__(*args, **kwargs) self.mixStyleModule = MixStyle(p=mix_p, alpha=mix_alpha) for layer in mix_layers: assert layer in ['layer1', 'layer2', 'layer3'] self.apply_layers = mix_layers def forward(self, x): x = self.conv1(x) x = self.bn1(x) # turn on relu activation here **except for** reid tasks if resnet_class != ReidResNet: x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) if 'layer1' in self.apply_layers: x = self.mixStyleModule(x) x = self.layer2(x) if 'layer2' in self.apply_layers: x = self.mixStyleModule(x) x = self.layer3(x) if 'layer3' in self.apply_layers: x = self.mixStyleModule(x) x = self.layer4(x) return x model = ResNetWithMixStyleModule(mix_layers=mix_layers, mix_p=mix_p, mix_alpha=mix_alpha, block=block, layers=layers, **kwargs) if pretrained: model_dict = model.state_dict() pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress) # remove keys from pretrained dict that doesn't appear in model dict 
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model.load_state_dict(pretrained_dict, strict=False) return model def resnet18(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-18 model with MixStyle. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet_with_mix_style('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def resnet34(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-34 model with MixStyle. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet_with_mix_style('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet50(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-50 model with MixStyle. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _resnet_with_mix_style('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnet101(pretrained=False, progress=True, **kwargs): """Constructs a ResNet-101 model with MixStyle. 
# ================================================
# FILE: tllib/normalization/mixstyle/sampler.py
# ================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import random
import copy
from torch.utils.data.dataset import ConcatDataset
from torch.utils.data.sampler import Sampler


class RandomDomainMultiInstanceSampler(Sampler):
    r"""Randomly sample :math:`N` domains, then randomly select :math:`P` classes in each domain,
    and for each selected class randomly select :math:`K` instances, forming a mini-batch of size
    :math:`N\times P\times K`.

    Args:
        dataset (ConcatDataset): dataset that contains data from multiple domains; each item must be
            a ``(data, class_label, domain_id)`` triple
        batch_size (int): mini-batch size (:math:`N\times P\times K` here)
        n_domains_per_batch (int): number of domains to select in a single mini-batch (:math:`N` here)
        num_instances (int): number of instances to select per class in each domain (:math:`K` here)
    """

    def __init__(self, dataset, batch_size, n_domains_per_batch, num_instances):
        super(Sampler, self).__init__()
        self.dataset = dataset
        # map: domain_id -> list of dataset indices belonging to that domain
        self.sample_idxes_per_domain = {}
        for idx, (_, _, domain_id) in enumerate(self.dataset):
            if domain_id not in self.sample_idxes_per_domain:
                self.sample_idxes_per_domain[domain_id] = []
            self.sample_idxes_per_domain[domain_id].append(idx)
        self.n_domains_in_dataset = len(self.sample_idxes_per_domain)
        self.n_domains_per_batch = n_domains_per_batch
        assert self.n_domains_in_dataset >= self.n_domains_per_batch
        assert batch_size % n_domains_per_batch == 0
        self.batch_size_per_domain = batch_size // n_domains_per_batch
        assert self.batch_size_per_domain % num_instances == 0
        self.num_instances = num_instances
        self.num_classes_per_domain = self.batch_size_per_domain // num_instances
        # NOTE: length is estimated by drawing one full (random) epoch of indices;
        # subsequent epochs may differ slightly in length because sampling is random.
        self.length = len(list(self.__iter__()))

    def __iter__(self):
        # work on a copy so consumed indices can be removed without mutating the originals
        sample_idxes_per_domain = copy.deepcopy(self.sample_idxes_per_domain)
        domain_idxes = [idx for idx in range(self.n_domains_in_dataset)]
        final_idxes = []
        stop_flag = False
        while not stop_flag:
            selected_domains = random.sample(domain_idxes, self.n_domains_per_batch)
            for domain in selected_domains:
                sample_idxes = sample_idxes_per_domain[domain]
                selected_idxes = self.sample_multi_instances(sample_idxes)
                final_idxes.extend(selected_idxes)
                # sampled indices are consumed and not reused within this epoch
                for idx in selected_idxes:
                    sample_idxes_per_domain[domain].remove(idx)
                remaining_size = len(sample_idxes_per_domain[domain])
                # stop once any domain can no longer fill its share of a batch
                if remaining_size < self.batch_size_per_domain:
                    stop_flag = True
        return iter(final_idxes)

    def sample_multi_instances(self, sample_idxes):
        # group the available indices of one domain by class label
        idxes_per_cls = {}
        for idx in sample_idxes:
            _, cls, _ = self.dataset[idx]
            if cls not in idxes_per_cls:
                idxes_per_cls[cls] = []
            idxes_per_cls[cls].append(idx)
        # keep only classes that still have at least num_instances samples
        cls_list = [cls for cls in idxes_per_cls if len(idxes_per_cls[cls]) >= self.num_instances]
        if len(cls_list) < self.num_classes_per_domain:
            # not enough populated classes: fall back to uniform sampling within the domain
            return random.sample(sample_idxes, self.batch_size_per_domain)
        selected_idxes = []
        selected_classes = random.sample(cls_list, self.num_classes_per_domain)
        for cls in selected_classes:
            selected_idxes.extend(random.sample(idxes_per_cls[cls], self.num_instances))
        return selected_idxes

    def __len__(self):
        return self.length
class _StochNorm(nn.Module):
    """Base class for Stochastic Normalization layers (`Stochastic Normalization, NeurIPS 2020`).

    During training, two branches are computed with :func:`F.batch_norm`: one normalizes with the
    running (moving) statistics, the other with the current mini-batch statistics. A per-channel
    Bernoulli mask ``s`` (``P(s=1)=p``) mixes the two. At evaluation time only the running
    statistics are used, like standard BatchNorm.

    Args:
        num_features (int): number of channels/features
        eps (float): value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): value used for the running statistics update. Default: 0.1
        affine (bool): if True, adds learnable per-channel weight and bias. Default: True
        track_running_stats (bool): if True, keeps running mean/var buffers. Default: True
        p (float): probability of choosing the mini-batch-statistics branch. Default: 0.5
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, p=0.5):
        super(_StochNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.p = p
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Same initialization as the original implementation: uniform weight, zero bias.
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
        if self.affine:
            self.weight.data.uniform_()
            self.bias.data.zero_()

    def _check_input_dim(self, input):
        # Overridden by subclasses to validate input dimensionality.
        return NotImplemented

    def forward(self, input):
        self._check_input_dim(input)

        if self.training:
            # Branch 0: normalize with running statistics (training=False -> no stats update).
            z_0 = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                False, self.momentum, self.eps)
            # Branch 1: normalize with mini-batch statistics (training=True -> updates running stats).
            z_1 = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                True, self.momentum, self.eps)

            if input.dim() not in (2, 3, 4):
                # FIX: was `raise BaseException()` — raise a specific, informative error instead.
                raise ValueError('expected 2D, 3D or 4D input (got {}D input)'.format(input.dim()))

            # Per-channel Bernoulli mask, broadcastable over the remaining dimensions.
            # numpy RNG is kept for backward-compatible seeding behavior.
            shape = (1, self.num_features) + (1,) * (input.dim() - 2)
            # FIX: was `.cuda()`, which crashed on CPU-only machines and ignored the
            # input's actual device; follow the input tensor instead.
            s = torch.from_numpy(
                np.random.binomial(n=1, p=self.p, size=self.num_features).reshape(shape)
            ).float().to(input.device)

            z = (1 - s) * z_0 + s * z_1
        else:
            # Evaluation: plain BatchNorm with running statistics.
            z = F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                False, self.momentum, self.eps)

        return z
class StochNorm1d(_StochNorm):
    r"""Stochastic Normalization over a 2D or 3D input (a mini-batch of 1D inputs with an optional
    channel dimension), from `Stochastic Normalization (NeurIPS 2020)`_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{\sqrt{\tilde{\sigma} + \epsilon}},\quad
        \hat{x}_{i,1} = \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}},\quad
        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1},\quad
        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` are the running statistics, and :math:`s` is a
    Bernoulli branch-selection variable with :math:`P(s=1)=p`. During evaluation only the
    running statistics are used, as in standard BatchNorm.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, l)`
            or :math:`l` from an expected input of size :math:`(b, l)`.
        eps (float): value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): if ``True``, this layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): if ``True``, track running mean and variance; otherwise always
            use batch statistics in both training and eval modes. Default: ``True``
        p (float): probability of choosing the mini-batch-statistics branch. Default: 0.5

    Shape:
        - Input: :math:`(b, l)` or :math:`(b, c, l)`
        - Output: same shape as input
    """

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class StochNorm2d(_StochNorm):
    r"""Stochastic Normalization over a 4D input (a mini-batch of 2D inputs with a channel
    dimension), from `Stochastic Normalization (NeurIPS 2020)`_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{\sqrt{\tilde{\sigma} + \epsilon}},\quad
        \hat{x}_{i,1} = \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}},\quad
        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1},\quad
        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` are the running statistics, and :math:`s` is a
    Bernoulli branch-selection variable with :math:`P(s=1)=p`. During evaluation only the
    running statistics are used, as in standard BatchNorm.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, h, w)`.
        eps (float): value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): if ``True``, this layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): if ``True``, track running mean and variance; otherwise always
            use batch statistics in both training and eval modes. Default: ``True``
        p (float): probability of choosing the mini-batch-statistics branch. Default: 0.5

    Shape:
        - Input: :math:`(b, c, h, w)`
        - Output: :math:`(b, c, h, w)` (same shape as input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
class StochNorm3d(_StochNorm):
    r"""Stochastic Normalization over a 5D input (a mini-batch of 3D inputs with a channel
    dimension), from `Stochastic Normalization (NeurIPS 2020)`_.

    .. math::
        \hat{x}_{i,0} = \frac{x_i - \tilde{\mu}}{\sqrt{\tilde{\sigma} + \epsilon}},\quad
        \hat{x}_{i,1} = \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}},\quad
        \hat{x}_i = (1-s)\cdot \hat{x}_{i,0} + s\cdot \hat{x}_{i,1},\quad
        y_i = \gamma \hat{x}_i + \beta

    Here :math:`\mu, \sigma` are the statistics of the current mini-batch,
    :math:`\tilde{\mu}, \tilde{\sigma}` are the running statistics, and :math:`s` is a
    Bernoulli branch-selection variable with :math:`P(s=1)=p`. During evaluation only the
    running statistics are used, as in standard BatchNorm.

    Args:
        num_features (int): :math:`c` from an expected input of size :math:`(b, c, d, h, w)`
        eps (float): value added to the denominator for numerical stability. Default: 1e-5
        momentum (float): value used for the running_mean and running_var computation. Default: 0.1
        affine (bool): if ``True``, this layer has learnable affine parameters. Default: ``True``
        track_running_stats (bool): if ``True``, track running mean and variance; otherwise always
            use batch statistics in both training and eval modes. Default: ``True``
        p (float): probability of choosing the mini-batch-statistics branch. Default: 0.5

    Shape:
        - Input: :math:`(b, c, d, h, w)`
        - Output: :math:`(b, c, d, h, w)` (same shape as input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            # FIX: the message previously said "expected 4D input" although the check is for 5D.
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
""" mod = module for pth_module, stoch_module in zip([torch.nn.modules.batchnorm.BatchNorm1d, torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.BatchNorm3d], [StochNorm1d, StochNorm2d, StochNorm3d]): if isinstance(module, pth_module): mod = stoch_module(module.num_features, module.eps, module.momentum, module.affine, p) mod.running_mean = module.running_mean mod.running_var = module.running_var if module.affine: mod.weight.data = module.weight.data.clone().detach() mod.bias.data = module.bias.data.clone().detach() for name, child in module.named_children(): mod.add_module(name, convert_model(child, p)) return mod ================================================ FILE: tllib/ranking/__init__.py ================================================ from .logme import log_maximum_evidence from .nce import negative_conditional_entropy from .leep import log_expected_empirical_prediction from .hscore import h_score __all__ = ['log_maximum_evidence', 'negative_conditional_entropy', 'log_expected_empirical_prediction', 'h_score'] ================================================ FILE: tllib/ranking/hscore.py ================================================ """ @author: Yong Liu @contact: liuyong1095556447@163.com """ import numpy as np from sklearn.covariance import LedoitWolf __all__ = ['h_score', 'regularized_h_score'] def h_score(features: np.ndarray, labels: np.ndarray): r""" H-score in `An Information-theoretic Approach to Transferability in Task Transfer Learning (ICIP 2019) `_. The H-Score :math:`\mathcal{H}` can be described as: .. math:: \mathcal{H}=\operatorname{tr}\left(\operatorname{cov}(f)^{-1} \operatorname{cov}\left(\mathbb{E}[f \mid y]\right)\right) where :math:`f` is the features extracted by the model to be ranked, :math:`y` is the groud-truth label vector Args: features (np.ndarray):features extracted by pre-trained model. labels (np.ndarray): groud-truth labels. Shape: - features: (N, F), with number of samples N and feature dimension F. 
def h_score(features: np.ndarray, labels: np.ndarray):
    r"""H-score from `An Information-theoretic Approach to Transferability in Task Transfer
    Learning (ICIP 2019)`_.

    .. math::
        \mathcal{H}=\operatorname{tr}\left(\operatorname{cov}(f)^{-1}
        \operatorname{cov}\left(\mathbb{E}[f \mid y]\right)\right)

    where :math:`f` are the features extracted by the model to be ranked and :math:`y`
    is the ground-truth label vector.

    Args:
        features (np.ndarray): features extracted by a pre-trained model.
        labels (np.ndarray): ground-truth labels.

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ), elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    num_classes = int(labels.max() + 1)
    cov_f = np.cov(features, rowvar=False)
    # replace each feature row by its class-conditional mean E[f|y]
    class_means = np.zeros_like(features)
    for c in range(num_classes):
        mask = labels == c
        class_means[mask] = np.mean(features[mask, :], axis=0)
    cov_g = np.cov(class_means, rowvar=False)
    # pseudo-inverse guards against a singular feature covariance
    return np.trace(np.dot(np.linalg.pinv(cov_f, rcond=1e-15), cov_g))


def regularized_h_score(features: np.ndarray, labels: np.ndarray):
    r"""Regularized H-score from `Newer is not always better: Rethinking transferability metrics,
    their peculiarities, stability and performance (NeurIPS 2021)`_.

    .. math::
        \mathcal{H}_{\alpha}=\operatorname{tr}\left(\operatorname{cov}_{\alpha}(f)^{-1}
        \left(1-\alpha\right)\operatorname{cov}\left(\mathbb{E}[f \mid y]\right)\right)

    where :math:`\operatorname{cov}_{\alpha}` is the Ledoit-Wolf covariance estimator with
    shrinkage parameter :math:`\alpha`.

    Args:
        features (np.ndarray): features extracted by a pre-trained model.
        labels (np.ndarray): ground-truth labels.

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ), elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    # float64 + centering are required for a correct Ledoit-Wolf estimate
    centered = features.astype('float64')
    centered = centered - np.mean(centered, axis=0, keepdims=True)
    num_classes = int(labels.max() + 1)

    shrunk = LedoitWolf(assume_centered=False).fit(centered)
    alpha = shrunk.shrinkage_
    cov_f_alpha = shrunk.covariance_

    class_means = np.zeros_like(centered)
    for c in range(num_classes):
        mask = labels == c
        class_means[mask] = np.mean(centered[mask, :], axis=0)
    cov_g = np.cov(class_means, rowvar=False)
    return np.trace(np.dot(np.linalg.pinv(cov_f_alpha, rcond=1e-15), (1 - alpha) * cov_g))
""" f = features.astype('float64') f = f - np.mean(f, axis=0, keepdims=True) # Center the features for correct Ledoit-Wolf Estimation y = labels C = int(y.max() + 1) g = np.zeros_like(f) cov = LedoitWolf(assume_centered=False).fit(f) alpha = cov.shrinkage_ covf_alpha = cov.covariance_ for i in range(C): Ef_i = np.mean(f[y == i, :], axis=0) g[y == i] = Ef_i covg = np.cov(g, rowvar=False) score = np.trace(np.dot(np.linalg.pinv(covf_alpha, rcond=1e-15), (1 - alpha) * covg)) return score ================================================ FILE: tllib/ranking/leep.py ================================================ """ @author: Yong Liu @contact: liuyong1095556447@163.com """ import numpy as np __all__ = ['log_expected_empirical_prediction'] def log_expected_empirical_prediction(predictions: np.ndarray, labels: np.ndarray): r""" Log Expected Empirical Prediction in `LEEP: A New Measure to Evaluate Transferability of Learned Representations (ICML 2020) `_. The LEEP :math:`\mathcal{T}` can be described as: .. math:: \mathcal{T}=\mathbb{E}\log \left(\sum_{z \in \mathcal{C}_s} \hat{P}\left(y \mid z\right) \theta\left(y \right)_{z}\right) where :math:`\theta\left(y\right)_{z}` is the predictions of pre-trained model on source category, :math:`\hat{P}\left(y \mid z\right)` is the empirical conditional distribution estimated by prediction and ground-truth label. Args: predictions (np.ndarray): predictions of pre-trained model. labels (np.ndarray): groud-truth labels. Shape: - predictions: (N, :math:`C_s`), with number of samples N and source class number :math:`C_s`. - labels: (N, ) elements in [0, :math:`C_t`), with target class number :math:`C_t`. 
# ================================================
# FILE: tllib/ranking/logme.py
# ================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np
from numba import njit

__all__ = ['log_maximum_evidence']


def log_maximum_evidence(features: np.ndarray, targets: np.ndarray, regression=False, return_weights=False):
    r"""Log Maximum Evidence in `LogME: Practical Assessment of Pre-trained Models
    for Transfer Learning (ICML 2021) `_.

    Args:
        features (np.ndarray): feature matrix from pre-trained model.
        targets (np.ndarray): targets labels/values.
        regression (bool, optional): whether to apply in regression setting. (Default: False)
        return_weights (bool, optional): whether to return bayesian weight. (Default: False)

    Shape:
        - features: (N, F) with element in [0, :math:`C_t`) and feature dimension F, where :math:`C_t` denotes the number of target class
        - targets: (N, ) or (N, C), with C regression-labels.
        - weights: (F, :math:`C_t`).
        - score: scalar.
    """
    f = features.astype(np.float64)
    y = targets
    if regression:
        y = targets.astype(np.float64)

    # fh is (N, F); f is transposed to (F, N) so that f @ fh is the (F, F) Gram matrix
    fh = f
    f = f.transpose()
    D, N = f.shape
    # SVD of the feature Gram matrix is shared by every per-class evidence computation
    v, s, vh = np.linalg.svd(f @ fh, full_matrices=True)

    evidences = []
    weights = []
    if regression:
        # one evidence term per regression output dimension
        C = y.shape[1]
        for i in range(C):
            y_ = y[:, i]
            evidence, weight = each_evidence(y_, f, fh, v, s, vh, N, D)
            evidences.append(evidence)
            weights.append(weight)
    else:
        # one-vs-rest binary targets, one evidence term per class
        C = int(y.max() + 1)
        for i in range(C):
            y_ = (y == i).astype(np.float64)
            evidence, weight = each_evidence(y_, f, fh, v, s, vh, N, D)
            evidences.append(evidence)
            weights.append(weight)
    score = np.mean(evidences)
    weights = np.vstack(weights)

    if return_weights:
        return score, weights
    else:
        return score


@njit
def each_evidence(y_, f, fh, v, s, vh, N, D):
    """
    Compute the maximum evidence for each class.

    Fixed-point iteration over the hyper-parameters (alpha, beta) of the Bayesian
    linear model; returns the per-sample log evidence and the posterior mean weights m.
    """
    alpha = 1.0
    beta = 1.0
    lam = alpha / beta
    tmp = (vh @ (f @ y_))
    for _ in range(11):
        # should converge after at most 10 steps
        # typically converge after two or three steps
        gamma = (s / (s + lam)).sum()
        m = v @ (tmp * beta / (alpha + beta * s))
        alpha_de = (m * m).sum()
        alpha = gamma / alpha_de
        beta_de = ((y_ - fh @ m) ** 2).sum()
        beta = (N - gamma) / beta_de
        new_lam = alpha / beta
        # relative-change stopping criterion
        if np.abs(new_lam - lam) / lam < 0.01:
            break
        lam = new_lam
    evidence = D / 2.0 * np.log(alpha) \
               + N / 2.0 * np.log(beta) \
               - 0.5 * np.sum(np.log(alpha + beta * s)) \
               - beta / 2.0 * beta_de \
               - alpha / 2.0 * alpha_de \
               - N / 2.0 * np.log(2 * np.pi)
    return evidence / N, m
# ================================================
# FILE: tllib/ranking/nce.py
# ================================================
"""
@author: Yong Liu
@contact: liuyong1095556447@163.com
"""
import numpy as np

__all__ = ['negative_conditional_entropy']


def negative_conditional_entropy(source_labels: np.ndarray, target_labels: np.ndarray):
    r"""Negative Conditional Entropy from `Transferability and Hardness of Supervised
    Classification Tasks (ICCV 2019)`_.

    .. math::
        \mathcal{H}=-\sum_{y \in \mathcal{C}_t} \sum_{z \in \mathcal{C}_s}
        \hat{P}(y, z) \log \frac{\hat{P}(y, z)}{\hat{P}(z)}

    where :math:`\hat{P}(z)` is the empirical source-label distribution and
    :math:`\hat{P}(y \mid z)` is the empirical conditional distribution estimated from the
    source and target labels.

    Args:
        source_labels (np.ndarray): predicted source labels.
        target_labels (np.ndarray): ground-truth target labels.

    Shape:
        - source_labels: (N, ), elements in [0, :math:`C_s`), with source class number :math:`C_s`.
        - target_labels: (N, ), elements in [0, :math:`C_t`), with target class number :math:`C_t`.
    """
    n_target = int(np.max(target_labels) + 1)
    n_source = int(np.max(source_labels) + 1)
    n_samples = len(source_labels)

    # empirical joint distribution P(y, z), shape (C_t, C_s)
    joint = np.zeros((n_target, n_source), dtype=float)
    for src, tgt in zip(source_labels, target_labels):
        joint[int(tgt), int(src)] += 1.0 / n_samples

    p_source = joint.sum(axis=0, keepdims=True)           # P(z), shape (1, C_s)
    cond = (joint / p_source).T                           # P(y | z), shape (C_s, C_t)
    valid = p_source.reshape(-1) != 0                     # rows where P(z) > 0
    # drop NaN rows where P(z) = 0; add 1e-20 to avoid log(0)
    cond = cond[valid] + 1e-20

    per_source_entropy = np.sum(-cond * np.log(cond), axis=1, keepdims=True)
    conditional_entropy = np.sum(per_source_entropy * p_source.reshape((-1, 1))[valid])
    return -conditional_entropy
def coding_rate(features: np.ndarray, eps=1e-4):
    """Rate-distortion coding rate of a feature matrix with distortion ``eps``."""
    n, d = features.shape
    scale = 1 / (n * eps)
    # slogdet avoids overflow/underflow of the determinant itself
    (_, rate) = np.linalg.slogdet(np.eye(d) + scale * features.transpose() @ features)
    return 0.5 * rate


def transrate(features: np.ndarray, labels: np.ndarray, eps=1e-4):
    r"""TransRate from `Frustratingly easy transferability estimation (ICML 2022)`_.

    .. math::
        TrR = R\left(f, \epsilon \right) - R\left(f, \epsilon \mid y \right)

    where :math:`f` are the features extracted by the model to be ranked, :math:`y` is the
    ground-truth label vector, and :math:`R` is the coding rate with distortion rate
    :math:`\epsilon`.

    Args:
        features (np.ndarray): features extracted by a pre-trained model.
        labels (np.ndarray): ground-truth labels.
        eps (float, optional): distortion rate (Default: 1e-4)

    Shape:
        - features: (N, F), with number of samples N and feature dimension F.
        - labels: (N, ), elements in [0, :math:`C_t`), with target class number :math:`C_t`.
        - score: scalar.
    """
    centered = features - np.mean(features, axis=0, keepdims=True)
    whole_rate = coding_rate(centered, eps)
    num_classes = int(labels.max() + 1)
    # average per-class coding rate approximates R(f, eps | y)
    class_rate = 0.0
    for c in range(num_classes):
        class_rate += coding_rate(centered[(labels == c).flatten()], eps)
    return whole_rate - class_rate / num_classes
# ================================================
# FILE: tllib/regularization/bi_tuning.py
# ================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import torch.nn as nn
from torch.nn.functional import normalize

from tllib.modules.classifier import Classifier as ClassifierBase


class Classifier(ClassifierBase):
    """Classifier for Bi-Tuning.

    Args:
        backbone (torch.nn.Module): any backbone that extracts 2-d features from data
        num_classes (int): number of classes
        projection_dim (int, optional): dimension of the projector head. Default: 128
        finetune (bool): whether to finetune the classifier or train from scratch. Default: True

    .. note::
        By default the classifier head's learning rate is 10 times that of the feature
        extractor for better accuracy. Override :meth:`~Classifier.get_parameters` for
        other optimization strategies.

    Inputs:
        - x (tensor): input data fed to `backbone`

    Outputs:
        In training mode: ``(y, z, hn)`` — the classifier predictions, the normalized
        projector output, and the normalized bottleneck feature augmented with a constant 1.
        In eval mode: ``y`` only.

    Shape:
        - Inputs: (minibatch, *) where * means any number of additional dimensions
        - y: (minibatch, `num_classes`)
        - z: (minibatch, `projection_dim`)
        - hn: (minibatch, `features_dim`)
    """

    def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=128, finetune=True, pool_layer=None):
        head = nn.Linear(backbone.out_features, num_classes)
        head.weight.data.normal_(0, 0.01)
        head.bias.data.fill_(0.0)
        super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head,
                                         finetune=finetune, pool_layer=pool_layer)
        self.projector = nn.Linear(backbone.out_features, projection_dim)
        self.projection_dim = projection_dim

    def forward(self, x: torch.Tensor):
        n = x.shape[0]
        feat = self.pool_layer(self.backbone(x))
        feat = self.bottleneck(feat)
        logits = self.head(feat)
        proj = normalize(self.projector(feat), dim=1)
        # append a constant-1 coordinate before normalizing, mirroring the bias term of the head
        ones = torch.ones(n, 1, dtype=torch.float).to(feat.device)
        feat_aug = normalize(torch.cat([feat, ones], dim=1), dim=1)
        if self.training:
            return logits, proj, feat_aug
        return logits

    def get_parameters(self, base_lr=1.0):
        """Parameter groups that fix the relative learning rate of each part of the model."""
        return [
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head.parameters(), "lr": 1.0 * base_lr},
            {"params": self.projector.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
        ]
class BiTuning(nn.Module):
    """Bi-Tuning Module in `Bi-tuning of Pre-trained Representations `_.

    Maintains per-class key queues (MoCo-style) and produces contrastive logits over the
    projector space and the classifier space.

    Args:
        encoder_q (Classifier): Query encoder.
        encoder_k (Classifier): Key encoder, momentum-updated copy of the query encoder.
        num_classes (int): Number of classes
        K (int): Queue size per class. Default: 40
        m (float): Momentum coefficient. Default: 0.999
        T (float): Temperature. Default: 0.07

    Inputs:
        - im_q (tensor): input data fed to `encoder_q`
        - im_k (tensor): input data fed to `encoder_k`
        - labels (tensor): classification labels of input data

    Outputs: y_q, logits_z, logits_y, labels_c
        - y_q: query classifier's predictions
        - logits_z: projector's predictions on both positive and negative samples
        - logits_y: classifier's predictions on both positive and negative samples
        - labels_c: contrastive labels

    Shape:
        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions
        - labels: (minibatch, )
        - y_q: (minibatch, `num_classes`)
        - logits_z: (minibatch, 1 + `num_classes` x `K`, `projection_dim`)
        - logits_y: (minibatch, 1 + `num_classes` x `K`, `num_classes`)
        - labels_c: (minibatch, 1 + `num_classes` x `K`)
    """

    def __init__(self, encoder_q: Classifier, encoder_k: Classifier, num_classes, K=40, m=0.999, T=0.07):
        super(BiTuning, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        self.num_classes = num_classes
        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k

        # key encoder starts as an exact copy of the query encoder and is never
        # updated by gradients, only by momentum (see _momentum_update_key_encoder)
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue: one slot of K keys per class, for both the augmented
        # feature space (queue_h) and the projector space (queue_z)
        self.register_buffer("queue_h", torch.randn(encoder_q.features_dim + 1, num_classes, K))
        self.register_buffer("queue_z", torch.randn(encoder_q.projection_dim, num_classes, K))
        self.queue_h = normalize(self.queue_h, dim=0)
        self.queue_z = normalize(self.queue_z, dim=0)
        # per-class write pointer into the circular queues
        self.register_buffer("queue_ptr", torch.zeros(num_classes, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, h, z, label):
        # circular enqueue of the new keys into the per-class queues
        batch_size = h.shape[0]
        assert self.K % batch_size == 0  # for simplicity
        ptr = int(self.queue_ptr[label])
        # replace the keys at ptr (dequeue and enqueue)
        self.queue_h[:, label, ptr: ptr + batch_size] = h.T
        self.queue_z[:, label, ptr: ptr + batch_size] = z.T
        # move pointer
        self.queue_ptr[label] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, labels):
        batch_size = im_q.size(0)
        device = im_q.device
        # compute query features
        y_q, z_q, h_q = self.encoder_q(im_q)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            y_k, z_k, h_k = self.encoder_k(im_k)

        # ---- contrastive logits in the projector space z ----
        # current positive logits: Nx1
        logits_z_cur = torch.einsum('nc,nc->n', [z_q, z_k]).unsqueeze(-1)
        queue_z = self.queue_z.clone().detach().to(device)
        # positive logits: N x K
        logits_z_pos = torch.Tensor([]).to(device)
        # negative logits: N x ((C-1) x K)
        logits_z_neg = torch.Tensor([]).to(device)
        for i in range(batch_size):
            c = labels[i]
            pos_samples = queue_z[:, c, :]  # D x K
            # all queued keys from every other class serve as negatives
            neg_samples = torch.cat([queue_z[:, 0: c, :], queue_z[:, c + 1:, :]], dim=1).flatten(
                start_dim=1)  # D x ((C-1)xK)
            ith_pos = torch.einsum('nc,ck->nk', [z_q[i: i + 1], pos_samples])  # 1 x D
            ith_neg = torch.einsum('nc,ck->nk', [z_q[i: i + 1], neg_samples])  # 1 x ((C-1)xK)
            logits_z_pos = torch.cat((logits_z_pos, ith_pos), dim=0)
            logits_z_neg = torch.cat((logits_z_neg, ith_neg), dim=0)
            # NOTE: queues are updated inside this loop, after the i-th sample's
            # logits are computed against the pre-update queue contents
            self._dequeue_and_enqueue(h_k[i:i + 1], z_k[i:i + 1], labels[i])

        logits_z = torch.cat([logits_z_cur, logits_z_pos, logits_z_neg], dim=1)  # Nx(1+C*K)
        # apply temperature
        logits_z /= self.T
        logits_z = nn.LogSoftmax(dim=1)(logits_z)

        # ---- contrastive logits in the classifier space y ----
        # fold the head's bias into an extra weight column to match the augmented features
        w = torch.cat([self.encoder_q.head.weight.data, self.encoder_q.head.bias.data.unsqueeze(-1)], dim=1)
        w = normalize(w, dim=1)  # C x F

        # current positive logits: Nx1
        logits_y_cur = torch.einsum('nk,kc->nc', [h_q, w.T])  # N x C
        queue_y = self.queue_h.clone().detach().to(device).flatten(start_dim=1).T  # (C * K) x F
        logits_y_queue = torch.einsum('nk,kc->nc', [queue_y, w.T]).reshape(self.num_classes, -1,
                                                                           self.num_classes)  # C x K x C
        logits_y = torch.Tensor([]).to(device)
        for i in range(batch_size):
            c = labels[i]
            # calculate the ith sample in the batch
            cur_sample = logits_y_cur[i:i + 1, c]  # 1
            pos_samples = logits_y_queue[c, :, c]  # K
            neg_samples = torch.cat([logits_y_queue[0: c, :, c], logits_y_queue[c + 1:, :, c]], dim=0).view(
                -1)  # (C-1)*K
            ith = torch.cat([cur_sample, pos_samples, neg_samples])  # 1+C*K
            logits_y = torch.cat([logits_y, ith.unsqueeze(dim=0)], dim=0)
        logits_y /= self.T
        logits_y = nn.LogSoftmax(dim=1)(logits_y)

        # contrastive labels: uniform mass over the 1 + K positive slots, zero elsewhere
        labels_c = torch.zeros([batch_size, self.K * self.num_classes + 1]).to(device)
        labels_c[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))

        return y_q, logits_z, logits_y, labels_c
class BatchSpectralShrinkage(nn.Module):
    r"""Regularization from `Catastrophic Forgetting Meets Negative Transfer: Batch Spectral
    Shrinkage for Safe Transfer Learning (NIPS 2019)`_.

    Penalizes the :math:`k` smallest singular values of the feature matrix :math:`F`:

    .. math::
        L_{bss}(F) = \sum_{i=1}^{k} \sigma_{-i}^2

    where :math:`\sigma_{-i}` is the :math:`i`-th smallest singular value of :math:`F`,
    obtained via SVD :math:`F = U\Sigma V^T`.

    Args:
        k (int): number of smallest singular values to be penalized. Default: 1

    Shape:
        - Input: :math:`(b, |\mathcal{f}|)` where :math:`b` is the batch size and
          :math:`|\mathcal{f}|` is the feature dimension.
        - Output: scalar.
    """

    def __init__(self, k=1):
        super(BatchSpectralShrinkage, self).__init__()
        self.k = k

    def forward(self, feature):
        penalty = 0
        # torch.svd returns singular values in descending order
        _, singular_values, _ = torch.svd(feature.t())
        total = singular_values.size(0)
        for j in range(self.k):
            # square of the j-th smallest singular value
            penalty += torch.pow(singular_values[total - 1 - j], 2)
        return penalty


# ================================================
# FILE: tllib/regularization/co_tuning.py
# ================================================
class CoTuningLoss(nn.Module):
    """The Co-Tuning loss from `Co-Tuning for Transfer Learning (NIPS 2020)`_.

    Cross-entropy between the soft relationship targets p(y_s | y_t) and the source-head
    predictions.

    Inputs:
        - input: p(y_s) predicted by the source classifier.
        - target: p(y_s | y_t), where y_t is the ground-truth class label in the target dataset.

    Shape:
        - input: (b, N_p), where b is the batch size and N_p is the number of source classes
        - target: (b, N_p), same shape as input
        - Outputs: scalar.
    """

    def __init__(self):
        super(CoTuningLoss, self).__init__()

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        per_sample = torch.sum(-target * F.log_softmax(input, dim=-1), dim=-1)
        return torch.mean(per_sample)
""" def __init__(self, data_loader, classifier, device, cache=None): super(Relationship, self).__init__() self.data_loader = data_loader self.classifier = classifier self.device = device if cache is None or not os.path.exists(cache): source_predictions, target_labels = self.collect_labels() self.relationship = self.get_category_relationship(source_predictions, target_labels) if cache is not None: np.save(cache, self.relationship) else: self.relationship = np.load(cache) def __getitem__(self, category): return self.relationship[category] def collect_labels(self): """ Collects predictions of target dataset by source model and corresponding ground truth class labels. Returns: - source_probabilities, [N, N_p], where N_p is the number of classes in source dataset - target_labels, [N], where 0 <= each number < N_t, and N_t is the number of classes in target dataset """ print("Collecting labels to calculate relationship") source_predictions = [] target_labels = [] self.classifier.eval() with torch.no_grad(): for i, (x, label) in enumerate(tqdm.tqdm(self.data_loader)): x = x.to(self.device) y_s = self.classifier(x) source_predictions.append(F.softmax(y_s, dim=1).detach().cpu().numpy()) target_labels.append(label) return np.concatenate(source_predictions, 0), np.concatenate(target_labels, 0) def get_category_relationship(self, source_probabilities, target_labels): """ The direct approach of learning category relationship p(y_s | y_t). 
Args: source_probabilities (numpy.array): [N, N_p], where N_p is the number of classes in source dataset target_labels (numpy.array): [N], where 0 <= each number < N_t, and N_t is the number of classes in target dataset Returns: Conditional probability, [N_c, N_p] matrix representing the conditional probability p(pre-trained class | target_class) """ N_t = np.max(target_labels) + 1 # the number of target classes conditional = [] for i in range(N_t): this_class = source_probabilities[target_labels == i] average = np.mean(this_class, axis=0, keepdims=True) conditional.append(average) return np.concatenate(conditional) class Classifier(ClassifierBase): """A Classifier used in `Co-Tuning for Transfer Learning (NIPS 2020) `_.. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data. num_classes (int): Number of classes. head_source (torch.nn.Module): Classifier head of source model. head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default finetune (bool): Whether finetune the classifier or train from scratch. 
Default: True Inputs: - x (tensor): input data fed to backbone Outputs: - y_s: predictions of source classifier head - y_t: predictions of target classifier head Shape: - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions - y_s: (b, N), where b is the batch size and N is the number of classes - y_t: (b, N), where b is the batch size and N is the number of classes """ def __init__(self, backbone: nn.Module, num_classes: int, head_source, **kwargs): super(Classifier, self).__init__(backbone, num_classes, head_source, **kwargs) def get_parameters(self, base_lr=1.0) -> List[Dict]: """A parameter list which decides optimization hyper-parameters, such as the relative learning rate of each layer """ params = [ {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.head_source.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, {"params": self.head_target.parameters(), "lr": 1.0 * base_lr}, ] return params ================================================ FILE: tllib/regularization/delta.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import torch import torch.nn as nn import functools from collections import OrderedDict class L2Regularization(nn.Module): r"""The L2 regularization of parameters :math:`w` can be described as: .. math:: {\Omega} (w) = \dfrac{1}{2} \Vert w\Vert_2^2 , Args: model (torch.nn.Module): The model to apply L2 penalty. Shape: - Output: scalar. 
""" def __init__(self, model: nn.Module): super(L2Regularization, self).__init__() self.model = model def forward(self): output = 0.0 for param in self.model.parameters(): output += 0.5 * torch.norm(param) ** 2 return output class SPRegularization(nn.Module): r""" The SP (Starting Point) regularization from `Explicit inductive bias for transfer learning with convolutional networks (ICML 2018) `_ The SP regularization of parameters :math:`w` can be described as: .. math:: {\Omega} (w) = \dfrac{1}{2} \Vert w-w_0\Vert_2^2 , where :math:`w_0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning. Args: source_model (torch.nn.Module): The source (starting point) model. target_model (torch.nn.Module): The target (fine-tuning) model. Shape: - Output: scalar. """ def __init__(self, source_model: nn.Module, target_model: nn.Module): super(SPRegularization, self).__init__() self.target_model = target_model self.source_weight = {} for name, param in source_model.named_parameters(): self.source_weight[name] = param.detach() def forward(self): output = 0.0 for name, param in self.target_model.named_parameters(): output += 0.5 * torch.norm(param - self.source_weight[name]) ** 2 return output class BehavioralRegularization(nn.Module): r""" The behavioral regularization from `DELTA:DEep Learning Transfer using Feature Map with Attention for convolutional networks (ICLR 2019) `_ It can be described as: .. math:: {\Omega} (w) = \sum_{j=1}^{N} \Vert FM_j(w, \boldsymbol x)-FM_j(w^0, \boldsymbol x)\Vert_2^2 , where :math:`w^0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning, :math:`FM_j(w, \boldsymbol x)` is feature maps generated from the :math:`j`-th layer of the model parameterized with :math:`w`, given the input :math:`\boldsymbol x`. 
Inputs: layer_outputs_source (OrderedDict): The dictionary for source model, where the keys are layer names and the values are feature maps correspondingly. layer_outputs_target (OrderedDict): The dictionary for target model, where the keys are layer names and the values are feature maps correspondingly. Shape: - Output: scalar. """ def __init__(self): super(BehavioralRegularization, self).__init__() def forward(self, layer_outputs_source, layer_outputs_target): output = 0.0 for fm_src, fm_tgt in zip(layer_outputs_source.values(), layer_outputs_target.values()): output += 0.5 * (torch.norm(fm_tgt - fm_src.detach()) ** 2) return output class AttentionBehavioralRegularization(nn.Module): r""" The behavioral regularization with attention from `DELTA:DEep Learning Transfer using Feature Map with Attention for convolutional networks (ICLR 2019) `_ It can be described as: .. math:: {\Omega} (w) = \sum_{j=1}^{N} W_j(w) \Vert FM_j(w, \boldsymbol x)-FM_j(w^0, \boldsymbol x)\Vert_2^2 , where :math:`w^0` is the parameter vector of the model pretrained on the source problem, acting as the starting point (SP) in fine-tuning. :math:`FM_j(w, \boldsymbol x)` is feature maps generated from the :math:`j`-th layer of the model parameterized with :math:`w`, given the input :math:`\boldsymbol x`. :math:`W_j(w)` is the channel attention of the :math:`j`-th layer of the model parameterized with :math:`w`. Args: channel_attention (list): The channel attentions of feature maps generated by each selected layer. For the layer with C channels, the channel attention is a tensor of shape [C]. Inputs: layer_outputs_source (OrderedDict): The dictionary for source model, where the keys are layer names and the values are feature maps correspondingly. layer_outputs_target (OrderedDict): The dictionary for target model, where the keys are layer names and the values are feature maps correspondingly. Shape: - Output: scalar. 
""" def __init__(self, channel_attention): super(AttentionBehavioralRegularization, self).__init__() self.channel_attention = channel_attention def forward(self, layer_outputs_source, layer_outputs_target): output = 0.0 for i, (fm_src, fm_tgt) in enumerate(zip(layer_outputs_source.values(), layer_outputs_target.values())): b, c, h, w = fm_src.shape fm_src = fm_src.reshape(b, c, h * w) fm_tgt = fm_tgt.reshape(b, c, h * w) distance = torch.norm(fm_tgt - fm_src.detach(), 2, 2) distance = c * torch.mul(self.channel_attention[i], distance ** 2) / (h * w) output += 0.5 * torch.sum(distance) return output def get_attribute(obj, attr, *args): def _getattr(obj, attr): return getattr(obj, attr, *args) return functools.reduce(_getattr, [obj] + attr.split('.')) class IntermediateLayerGetter: r""" Wraps a model to get intermediate output values of selected layers. Args: model (torch.nn.Module): The model to collect intermediate layer feature maps. return_layers (list): The names of selected modules to return the output. keep_output (bool): If True, `model_output` contains the final model's output, else return None. Default: True Returns: - An OrderedDict of intermediate outputs. The keys are selected layer names in `return_layers` and the values are the feature map outputs. The order is the same as `return_layers`. - The model's final output. If `keep_output` is False, return None. 
""" def __init__(self, model, return_layers, keep_output=True): self._model = model self.return_layers = return_layers self.keep_output = keep_output def __call__(self, *args, **kwargs): ret = OrderedDict() handles = [] for name in self.return_layers: layer = get_attribute(self._model, name) def hook(module, input, output, name=name): ret[name] = output try: h = layer.register_forward_hook(hook) except AttributeError as e: raise AttributeError(f'Module {name} not found') handles.append(h) if self.keep_output: output = self._model(*args, **kwargs) else: self._model(*args, **kwargs) output = None for h in handles: h.remove() return ret, output ================================================ FILE: tllib/regularization/knowledge_distillation.py ================================================ import torch.nn as nn import torch.nn.functional as F class KnowledgeDistillationLoss(nn.Module): """Knowledge Distillation Loss. Args: T (double): Temperature. Default: 1. reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. 
Default: ``'batchmean'``

    Inputs:
        - y_student (tensor): logits output of the student
        - y_teacher (tensor): logits output of the teacher

    Shape:
        - y_student: (minibatch, `num_classes`)
        - y_teacher: (minibatch, `num_classes`)
    """

    def __init__(self, T=1., reduction='batchmean'):
        super(KnowledgeDistillationLoss, self).__init__()
        self.T = T
        self.kl = nn.KLDivLoss(reduction=reduction)

    def forward(self, y_student, y_teacher):
        """KL divergence between temperature-softened student and teacher distributions."""
        return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1))


================================================ FILE: tllib/regularization/lwf.py ================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
import tqdm


def collect_pretrain_labels(data_loader, classifier, device):
    # Run the (frozen) source classifier over the whole loader and collect
    # its raw logits on CPU, concatenated over the dataset.
    source_predictions = []

    classifier.eval()
    with torch.no_grad():
        for i, (x, label) in enumerate(tqdm.tqdm(data_loader)):
            x = x.to(device)
            y_s = classifier(x)
            source_predictions.append(y_s.detach().cpu())
    return torch.cat(source_predictions, dim=0)


class Classifier(nn.Module):
    """A Classifier used in `Learning Without Forgetting (ECCV 2016)`_.

    Args:
        backbone (torch.nn.Module): Any backbone to extract 2-d features from data.
        num_classes (int): Number of classes.
        head_source (torch.nn.Module): Classifier head of source model.
        head_target (torch.nn.Module, optional): Any classifier head. Use :class:`torch.nn.Linear` by default
        finetune (bool): Whether finetune the classifier or train from scratch. Default: True

    Inputs:
        - x (tensor): input data fed to backbone

    Outputs:
        - y_s: predictions of source classifier head
        - y_t: predictions of target classifier head

    Shape:
        - Inputs: (b, *) where b is the batch size and * means any number of additional dimensions
        - y_s: (b, N), where b is the batch size and N is the number of classes
        - y_t: (b, N), where b is the batch size and N is the number of classes
    """

    def __init__(self, backbone: nn.Module, num_classes: int, head_source,
                 head_target: Optional[nn.Module] = None, bottleneck: Optional[nn.Module] = None,
                 bottleneck_dim: Optional[int] = -1, finetune=True, pool_layer=None):
        super(Classifier, self).__init__()
        self.backbone = backbone
        self.num_classes = num_classes
        if pool_layer is None:
            # Default pooling: global average pool + flatten to 2-d features.
            self.pool_layer = nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten()
            )
        else:
            self.pool_layer = pool_layer
        if bottleneck is None:
            self.bottleneck = nn.Identity()
            self._features_dim = backbone.out_features
        else:
            self.bottleneck = bottleneck
            assert bottleneck_dim > 0
            self._features_dim = bottleneck_dim
        self.head_source = head_source
        if head_target is None:
            self.head_target = nn.Linear(self._features_dim, num_classes)
        else:
            self.head_target = head_target
        self.finetune = finetune

    @property
    def features_dim(self) -> int:
        """The dimension of features before the final `head` layer"""
        return self._features_dim

    def forward(self, x: torch.Tensor):
        """"""
        f = self.backbone(x)
        f = self.pool_layer(f)
        # Source head sees the pooled features directly; target head sees the
        # (optional) bottleneck output.
        y_s = self.head_source(f)
        y_t = self.head_target(self.bottleneck(f))
        if self.training:
            return y_s, y_t
        else:
            return y_t

    def get_parameters(self, base_lr=1.0) -> List[Dict]:
        """A parameter list which decides optimization hyper-parameters,
            such as the relative learning rate of each layer
        """
        # NOTE: the source head is intentionally not optimized here (kept
        # frozen as the "old task" head); its entry is commented out below.
        params = [
            {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            # {"params": self.head_source.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr},
            {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr},
            {"params": self.head_target.parameters(), "lr": 1.0 * base_lr},
        ]
        return params


================================================ FILE: tllib/reweight/__init__.py ================================================

================================================ FILE: tllib/reweight/groupdro.py ================================================
"""
Modified from https://github.com/facebookresearch/DomainBed
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
import torch


class AutomaticUpdateDomainWeightModule(object):
    r"""
    Maintaining group weight based on loss history of all domains according to
    `Distributionally Robust Neural Networks for Group Shifts: On the Importance of
    Regularization for Worst-Case Generalization (ICLR 2020)`_.

    Suppose we have :math:`N` domains. During each iteration, we first calculate unweighted loss
    among all domains, resulting in :math:`loss\in R^N`. Then we update domain weight by

    .. math::
        w_k = w_k * \text{exp}(\eta \cdot loss_k), \forall k \in [1, N]

    where :math:`\eta` is the hyper parameter which ensures smoother change of weight.
    As :math:`w \in R^N` denotes a distribution, we `normalize` :math:`w` by its sum.
    At last, weighted loss is calculated as our objective

    .. math::
        objective = \sum_{k=1}^N w_k * loss_k

    Args:
        num_domains (int): The number of source domains.
        eta (float): Hyper parameter eta.
        device (torch.device): The device to run on.
    """

    def __init__(self, num_domains: int, eta: float, device):
        # Start from the uniform distribution over domains.
        self.domain_weight = torch.ones(num_domains).to(device) / num_domains
        self.eta = eta

    def get_domain_weight(self, sampled_domain_idxes):
        """Get domain weight to calculate final objective.
Inputs:
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_idxes: :math:`(D, )` where D means the number of sampled domains in current mini-batch
            - Outputs: :math:`(D, )`
        """
        # Renormalize over the sampled subset so the returned weights form a
        # distribution on this mini-batch's domains.
        domain_weight = self.domain_weight[sampled_domain_idxes]
        domain_weight = domain_weight / domain_weight.sum()
        return domain_weight

    def update(self, sampled_domain_losses: torch.Tensor, sampled_domain_idxes):
        """Update domain weight using loss of current mini-batch.

        Inputs:
            - sampled_domain_losses (tensor): loss among sampled domains in current mini-batch
            - sampled_domain_idxes (list): sampled domain indexes in current mini-batch

        Shape:
            - sampled_domain_losses: :math:`(D, )` where D means the number of sampled domains in current mini-batch
            - sampled_domain_idxes: :math:`(D, )`
        """
        sampled_domain_losses = sampled_domain_losses.detach()

        # Multiplicative (exponentiated-gradient) update; weights are
        # normalized later in get_domain_weight.
        for loss, idx in zip(sampled_domain_losses, sampled_domain_idxes):
            self.domain_weight[idx] *= (self.eta * loss).exp()


================================================ FILE: tllib/reweight/iwan.py ================================================
"""
@author: Baixu Chen
@contact: cbx_99_hasta@outlook.com
"""
from typing import Optional, List, Dict
import torch
import torch.nn as nn
from tllib.modules.classifier import Classifier as ClassifierBase


class ImportanceWeightModule(object):
    r"""
    Calculating class weight based on the output of discriminator.
    Introduced by `Importance Weighted Adversarial Nets for Partial Domain Adaptation (CVPR 2018)`_

    Args:
        discriminator (torch.nn.Module): A domain discriminator object, which predicts the domains
            of features. Its input shape is :math:`(N, F)` and output shape is :math:`(N, 1)`
        partial_classes_index (list[int], optional): The index of partial classes. Note that this
            parameter is just for debugging, since in real-world dataset, we have no access to the
            index of partial classes. Default: None.

    Examples::

        >>> domain_discriminator = DomainDiscriminator(1024, 1024)
        >>> importance_weight_module = ImportanceWeightModule(domain_discriminator)
        >>> num_iterations = 10000
        >>> for _ in range(num_iterations):
        >>>     # feature from source domain
        >>>     f_s = torch.randn(32, 1024)
        >>>     # importance weights for source instance
        >>>     w_s = importance_weight_module.get_importance_weight(f_s)
    """

    def __init__(self, discriminator: nn.Module, partial_classes_index: Optional[List[int]] = None):
        self.discriminator = discriminator
        self.partial_classes_index = partial_classes_index

    def get_importance_weight(self, feature):
        """
        Get importance weights for each instance.

        Args:
            feature (tensor): feature from source domain, in shape :math:`(N, F)`

        Returns:
            instance weight in shape :math:`(N, 1)`
        """
        # Instances the discriminator confidently labels as source get small
        # weight; normalize by the batch mean and detach so no gradient flows
        # back into the discriminator through the weights.
        weight = 1. - self.discriminator(feature)
        weight = weight / (weight.mean() + 1e-5)
        weight = weight.detach()
        return weight

    def get_partial_classes_weight(self, weights: torch.Tensor, labels: torch.Tensor):
        """
        Get class weight averaged on the partial classes and non-partial classes respectively.

        Args:
            weights (tensor): instance weight in shape :math:`(N, 1)`
            labels (tensor): ground truth labels in shape :math:`(N, 1)`

        .. warning::
            This function is just for debugging, since in real-world dataset, we have no access
            to the index of partial classes and this function will throw an error when
            `partial_classes_index` is None.
        """
        assert self.partial_classes_index is not None

        weights = weights.squeeze()
        is_partial = torch.Tensor([label in self.partial_classes_index for label in labels]).to(weights.device)
        if is_partial.sum() > 0:
            partial_classes_weight = (weights * is_partial).sum() / is_partial.sum()
        else:
            partial_classes_weight = torch.tensor(0)

        not_partial = 1. - is_partial
        if not_partial.sum() > 0:
            not_partial_classes_weight = (weights * not_partial).sum() / not_partial.sum()
        else:
            not_partial_classes_weight = torch.tensor(0)
        return partial_classes_weight, not_partial_classes_weight


class ImageClassifier(ClassifierBase):
    r"""The Image Classifier for `Importance Weighted Adversarial Nets for Partial Domain Adaptation`_"""

    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)


================================================ FILE: tllib/reweight/pada.py ================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional, List, Tuple
from torch.utils.data.dataloader import DataLoader
import torch.nn as nn
import torch
import torch.nn.functional as F


class AutomaticUpdateClassWeightModule(object):
    r"""
    Calculating class weight based on the output of classifier. See ``ClassWeightModule`` about
    the details of the calculation. Every N iterations, the class weight is updated automatically.

    Args:
        update_steps (int): N, the number of iterations to update class weight.
        data_loader (torch.utils.data.DataLoader): The data loader from which we can collect
            classification outputs.
        classifier (torch.nn.Module): Classifier.
        num_classes (int): Number of classes.
        device (torch.device): The device to run classifier.
        temperature (float, optional): T, temperature in ClassWeightModule. Default: 0.1
        partial_classes_index (list[int], optional): The index of partial classes. Note that this
            parameter is just for debugging, since in real-world dataset, we have no access to the
            index of partial classes. Default: None.
Examples:: >>> class_weight_module = AutomaticUpdateClassWeightModule(update_steps=500, ...) >>> num_iterations = 10000 >>> for _ in range(num_iterations): >>> class_weight_module.step() >>> # weight for F.cross_entropy >>> w_c = class_weight_module.get_class_weight_for_cross_entropy_loss() >>> # weight for tllib.alignment.dann.DomainAdversarialLoss >>> w_s, w_t = class_weight_module.get_class_weight_for_adversarial_loss() """ def __init__(self, update_steps: int, data_loader: DataLoader, classifier: nn.Module, num_classes: int, device: torch.device, temperature: Optional[float] = 0.1, partial_classes_index: Optional[List[int]] = None): self.update_steps = update_steps self.data_loader = data_loader self.classifier = classifier self.device = device self.class_weight_module = ClassWeightModule(temperature) self.class_weight = torch.ones(num_classes).to(device) self.num_steps = 0 self.partial_classes_index = partial_classes_index if partial_classes_index is not None: self.non_partial_classes_index = [c for c in range(num_classes) if c not in partial_classes_index] def step(self): self.num_steps += 1 if self.num_steps % self.update_steps == 0: all_outputs = collect_classification_results(self.data_loader, self.classifier, self.device) self.class_weight = self.class_weight_module(all_outputs) def get_class_weight_for_cross_entropy_loss(self): """ Outputs: weight for F.cross_entropy Shape: :math:`(C, )` where C means the number of classes. 
""" return self.class_weight def get_class_weight_for_adversarial_loss(self, source_labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Outputs: - w_s: source weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss` - w_t: target weight for :py:class:`~tllib.alignment.dann.DomainAdversarialLoss` Shape: - w_s: :math:`(minibatch, )` - w_t: :math:`(minibatch, )` """ class_weight_adv_source = self.class_weight[source_labels] class_weight_adv_target = torch.ones_like(class_weight_adv_source) * class_weight_adv_source.mean() return class_weight_adv_source, class_weight_adv_target def get_partial_classes_weight(self): """ Get class weight averaged on the partial classes and non-partial classes respectively. .. warning:: This function is just for debugging, since in real-world dataset, we have no access to the index of \ partial classes and this function will throw an error when `partial_classes_index` is None. """ assert self.partial_classes_index is not None return torch.mean(self.class_weight[self.partial_classes_index]), torch.mean( self.class_weight[self.non_partial_classes_index]) class ClassWeightModule(nn.Module): r""" Calculating class weight based on the output of classifier. Introduced by `Partial Adversarial Domain Adaptation (ECCV 2018) `_ Given classification logits outputs :math:`\{\hat{y}_i\}_{i=1}^n`, where :math:`n` is the dataset size, the weight indicating the contribution of each class to the training can be calculated as follows .. math:: \mathcal{\gamma} = \dfrac{1}{n} \sum_{i=1}^{n}\text{softmax}( \hat{y}_i / T), where :math:`\mathcal{\gamma}` is a :math:`|\mathcal{C}|`-dimensional weight vector quantifying the contribution of each class and T is a hyper-parameters called temperature. In practice, it's possible that some of the weights are very small, thus, we normalize weight :math:`\mathcal{\gamma}` by dividing its largest element, i.e. 
:math:`\mathcal{\gamma} \leftarrow \mathcal{\gamma} / max(\mathcal{\gamma})` Args: temperature (float, optional): hyper-parameters :math:`T`. Default: 0.1 Shape: - Inputs: (minibatch, :math:`|\mathcal{C}|`) - Outputs: (:math:`|\mathcal{C}|`,) """ def __init__(self, temperature: Optional[float] = 0.1): super(ClassWeightModule, self).__init__() self.temperature = temperature def forward(self, outputs: torch.Tensor): outputs.detach_() softmax_outputs = F.softmax(outputs / self.temperature, dim=1) class_weight = torch.mean(softmax_outputs, dim=0) class_weight = class_weight / torch.max(class_weight) class_weight = class_weight.view(-1) return class_weight def collect_classification_results(data_loader: DataLoader, classifier: nn.Module, device: torch.device) -> torch.Tensor: """ Fetch data from `data_loader`, and then use `classifier` to collect classification results Args: data_loader (torch.utils.data.DataLoader): Data loader. classifier (torch.nn.Module): A classifier. device (torch.device) Returns: Classification results in shape (len(data_loader), :math:`|\mathcal{C}|`). 
""" training = classifier.training classifier.eval() all_outputs = [] with torch.no_grad(): for i, (images, target) in enumerate(data_loader): images = images.to(device) output = classifier(images) all_outputs.append(output) classifier.train(training) return torch.cat(all_outputs, dim=0) ================================================ FILE: tllib/self_training/__init__.py ================================================ ================================================ FILE: tllib/self_training/cc_loss.py ================================================ """ @author: Ying Jin @contact: sherryying003@gmail.com """ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from tllib.modules.classifier import Classifier as ClassifierBase from ..modules.entropy import entropy __all__ = ['CCConsistency'] class CCConsistency(nn.Module): r""" CC Loss attach class confusion consistency to MCC. Args: temperature (float) : The temperature for rescaling, the prediction will shrink to vanilla softmax if temperature is 1.0. thr (float): The confidence threshold. .. note:: Make sure that temperature is larger than 0. Confidence threshold is larger than 0, smaller than 1.0. Inputs: g_t - g_t (tensor): unnormalized classifier predictions on target domain, :math:`g^t` - g_t_strong (tensor): unnormalized classifier predictions on target domain, with strong data augmentation, :math:`g^t_{strong}` Shape: - g_t, g_t_strong: :math:`(minibatch, C)` where C means the number of classes. - Output: scalar. 
Examples:: >>> temperature = 2.0 >>> loss = CCConsistency(temperature) >>> # logits output from target domain >>> g_t = torch.randn(batch_size, num_classes) >>> g_t_strong = torch.randn(batch_size, num_classes) >>> output = loss(g_t, g_t_strong) """ def __init__(self, temperature: float, thr=0.7): super(CCConsistency, self).__init__() self.temperature = temperature self.thr = thr def forward(self, logits: torch.Tensor, logits_strong: torch.Tensor) -> torch.Tensor: batch_size, num_classes = logits.shape logits = logits.detach() prediction_thr = F.softmax(logits / self.temperature, dim=1) max_probs, max_idx = torch.max(prediction_thr, dim=-1) mask_binary = max_probs.ge(self.thr) ### 0.7 for DomainNet, 0.95 for other datasets mask = mask_binary.float().detach() if mask.sum() == 0: return 0, 0 else: logits = logits[mask_binary] logits_strong = logits_strong[mask_binary] predictions = F.softmax(logits / self.temperature, dim=1) # batch_size x num_classes entropy_weight = entropy(predictions).detach() entropy_weight = 1 + torch.exp(-entropy_weight) entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1) # batch_size x 1 class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1) predictions_stong = F.softmax(logits_strong / self.temperature, dim=1) entropy_weight_strong = entropy(predictions_stong).detach() entropy_weight_strong = 1 + torch.exp(-entropy_weight_strong) entropy_weight_strong = (batch_size * entropy_weight_strong / torch.sum(entropy_weight_strong)).unsqueeze(dim=1) # batch_size x 1 class_confusion_matrix_strong = torch.mm((predictions_stong * entropy_weight_strong).transpose(1, 0), predictions_stong) # num_classes x num_classes class_confusion_matrix_strong = class_confusion_matrix_strong / torch.sum(class_confusion_matrix_strong, dim=1) consistency_loss = 
((class_confusion_matrix - class_confusion_matrix_strong) ** 2).sum() / num_classes * mask.sum() / batch_size #mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes return consistency_loss, mask.sum()/batch_size ================================================ FILE: tllib/self_training/dst.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch import torch.nn as nn import torch.nn.functional as F from tllib.modules.grl import WarmStartGradientReverseLayer from tllib.modules.classifier import Classifier class ImageClassifier(Classifier): r""" Classifier with non-linear pseudo head :math:`h_{\text{pseudo}}` and worst-case estimation head :math:`h_{\text{worst}}` from `Debiased Self-Training for Semi-Supervised Learning `_. Both heads are directly connected to the feature extractor :math:`\psi`. We implement end-to-end adversarial training procedure between :math:`\psi` and :math:`h_{\text{worst}}` by introducing a gradient reverse layer. Note that both heads can be safely discarded during inference, and thus will introduce no inference cost. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data num_classes (int): Number of classes bottleneck_dim (int, optional): Feature dimension of the bottleneck layer. width (int, optional): Hidden dimension of the non-linear pseudo head and worst-case estimation head. 
Inputs: - x (tensor): input data fed to `backbone` Outputs: - outputs: predictions of the main head :math:`h` - outputs_adv: predictions of the worst-case estimation head :math:`h_{\text{worst}}` - outputs_pseudo: predictions of the pseudo head :math:`h_{\text{pseudo}}` Shape: - Inputs: (minibatch, *) where * means, any number of additional dimensions - outputs, outputs_adv, outputs_pseudo: (minibatch, `num_classes`) """ def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, width=2048, **kwargs): bottleneck = nn.Sequential( nn.Linear(backbone.out_features, bottleneck_dim), nn.BatchNorm1d(bottleneck_dim), nn.ReLU(), nn.Dropout(0.5) ) bottleneck[0].weight.data.normal_(0, 0.005) bottleneck[0].bias.data.fill_(0.1) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) self.pseudo_head = nn.Sequential( nn.Linear(self.features_dim, width), nn.ReLU(), nn.Dropout(0.5), nn.Linear(width, self.num_classes) ) self.grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000, auto_step=False) self.adv_head = nn.Sequential( nn.Linear(self.features_dim, width), nn.ReLU(), nn.Dropout(0.5), nn.Linear(width, self.num_classes) ) def forward(self, x: torch.Tensor): f = self.pool_layer(self.backbone(x)) f = self.bottleneck(f) f_adv = self.grl_layer(f) outputs_adv = self.adv_head(f_adv) outputs = self.head(f) outputs_pseudo = self.pseudo_head(f) if self.training: return outputs, outputs_adv, outputs_pseudo else: return outputs def get_parameters(self, base_lr=1.0): """A parameter list which decides optimization hyper-parameters, such as the relative learning rate of each layer """ params = [ {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, {"params": self.head.parameters(), "lr": 1.0 * base_lr}, {"params": self.pseudo_head.parameters(), "lr": 1.0 * base_lr}, {"params": 
def shift_log(x, offset=1e-6):
    """Shift ``x`` by a small ``offset`` before taking the log, for numerical stability.

    The shifted value is clamped to at most 1 so the result never becomes positive.
    """
    return torch.clamp(x + offset, max=1.).log()


class WorstCaseEstimationLoss(nn.Module):
    r"""Worst-case estimation loss from `Debiased Self-Training for Semi-Supervised Learning
    <https://arxiv.org/abs/2202.07136>`_.

    It forces the worst possible head :math:`h_{\text{worst}}` to predict correctly on all
    labeled samples :math:`\mathcal{L}` while making as many mistakes as possible on unlabeled
    data :math:`\mathcal{U}`:

    .. math::
        loss(\mathcal{L}, \mathcal{U}) =
        \eta' \mathbb{E}_{y^l, y_{adv}^l \sim\hat{\mathcal{L}}}
        -\log\left(\frac{\exp(y_{adv}^l[h_{y^l}])}{\sum_j \exp(y_{adv}^l[j])}\right) +
        \mathbb{E}_{y^u, y_{adv}^u \sim\hat{\mathcal{U}}}
        -\log\left(1-\frac{\exp(y_{adv}^u[h_{y^u}])}{\sum_j \exp(y_{adv}^u[j])}\right),

    where :math:`y` are logits of the main head and :math:`y_{adv}` logits of the worst-case
    estimation head, and :math:`h_y` is the predicted label for logits :math:`y`.

    Args:
        eta_prime (float): the trade-off hyper parameter :math:`\eta'` weighting the labeled term.

    Inputs:
        - y_l: logits of the main head on labeled data
        - y_l_adv: logits of the worst-case estimation head on labeled data
        - y_u: logits of the main head on unlabeled data
        - y_u_adv: logits of the worst-case estimation head on unlabeled data

    Shape:
        - Inputs: :math:`(minibatch, C)` where C denotes the number of classes.
        - Output: scalar.
    """

    def __init__(self, eta_prime):
        super(WorstCaseEstimationLoss, self).__init__()
        self.eta_prime = eta_prime

    def forward(self, y_l, y_l_adv, y_u, y_u_adv):
        # labeled term: the adversarial head must agree with the main head's prediction
        labeled_predictions = y_l.argmax(dim=1)
        labeled_loss = self.eta_prime * F.cross_entropy(y_l_adv, labeled_predictions)

        # unlabeled term: the adversarial head must disagree with the main head's prediction,
        # expressed as NLL on the complement probability (shift_log keeps it finite)
        unlabeled_predictions = y_u.argmax(dim=1)
        complement_prob = 1. - F.softmax(y_u_adv, dim=1)
        unlabeled_loss = F.nll_loss(shift_log(complement_prob), unlabeled_predictions)

        return labeled_loss + unlabeled_loss
class DynamicThresholdingModule(object):
    r"""Dynamic thresholding module from `FlexMatch: Boosting Semi-Supervised Learning with
    Curriculum Pseudo Labeling <https://arxiv.org/abs/2110.08263>`_.

    At time :math:`t`, for each category :math:`c` the learning status :math:`\sigma_t(c)` is
    estimated by the number of samples whose predictions fall into this class with confidence
    above a fixed threshold (e.g. 0.95). The status is normalized to :math:`[0, 1]` by

    .. math::
        \beta_t(c) = \frac{\sigma_t(c)}{\underset{c'}{\text{max}}~\sigma_t(c')},

    and the dynamic threshold is :math:`\mathcal{T}_t(c) = \mathcal{M}(\beta_t(c)) \cdot \tau`,
    where :math:`\tau` is the pre-defined threshold and :math:`\mathcal{M}` a (possibly
    non-linear) mapping function.

    Args:
        threshold (float): The pre-defined confidence threshold :math:`\tau`
        warmup (bool): Whether to perform threshold warm-up. If True, unlabeled samples that
            never produced a confident prediction (recorded as -1) also count when normalizing
            :math:`\sigma_t(c)`
        mapping_func (callable): An increasing mapping function :math:`\mathcal{M}`
        num_classes (int): Number of classes
        n_unlabeled_samples (int): Size of the unlabeled dataset
        device (torch.device): Device on which the status tensor lives
    """

    def __init__(self, threshold, warmup, mapping_func, num_classes, n_unlabeled_samples, device):
        self.threshold = threshold
        self.warmup = warmup
        self.mapping_func = mapping_func
        self.num_classes = num_classes
        self.n_unlabeled_samples = n_unlabeled_samples
        # one slot per unlabeled sample; -1 marks "no confident prediction recorded yet"
        self.net_outputs = torch.full((n_unlabeled_samples,), -1, dtype=torch.long, device=device)
        self.device = device

    def get_threshold(self, pseudo_labels):
        """Calculate and return the dynamic threshold for each entry of ``pseudo_labels``."""
        label_counts = Counter(self.net_outputs.tolist())
        if max(label_counts.values()) == self.n_unlabeled_samples:
            # Early in training the network outputs no confident pseudo labels, so every
            # slot holds the same value; treat the learning status of all categories as zero.
            status = torch.zeros(self.num_classes).to(self.device)
        else:
            if not self.warmup and -1 in label_counts.keys():
                label_counts.pop(-1)
            normalizer = max(label_counts.values())
            # Counter returns 0 for classes that never appeared
            status = torch.FloatTensor(
                [label_counts[c] / normalizer for c in range(self.num_classes)]
            ).to(self.device)
        return self.threshold * self.mapping_func(status[pseudo_labels])

    def update(self, idxes, selected_mask, pseudo_labels):
        """Update the learning status.

        Args:
            idxes (tensor): Indexes of corresponding samples
            selected_mask (tensor): A binary mask, a value of 1 indicates the prediction for
                this sample will be updated
            pseudo_labels (tensor): Network predictions
        """
        selected = selected_mask == 1
        if idxes[selected].nelement() != 0:
            self.net_outputs[idxes[selected]] = pseudo_labels[selected]
Examples:: >>> temperature = 2.0 >>> loss = MinimumClassConfusionLoss(temperature) >>> # logits output from target domain >>> g_t = torch.randn(batch_size, num_classes) >>> output = loss(g_t) MCC can also serve as a regularizer for existing methods. Examples:: >>> from tllib.modules.domain_discriminator import DomainDiscriminator >>> num_classes = 2 >>> feature_dim = 1024 >>> batch_size = 10 >>> temperature = 2.0 >>> discriminator = DomainDiscriminator(in_feature=feature_dim, hidden_size=1024) >>> cdan_loss = ConditionalDomainAdversarialLoss(discriminator, reduction='mean') >>> mcc_loss = MinimumClassConfusionLoss(temperature) >>> # features from source domain and target domain >>> f_s, f_t = torch.randn(batch_size, feature_dim), torch.randn(batch_size, feature_dim) >>> # logits output from source domain adn target domain >>> g_s, g_t = torch.randn(batch_size, num_classes), torch.randn(batch_size, num_classes) >>> total_loss = cdan_loss(g_s, f_s, g_t, f_t) + mcc_loss(g_t) """ def __init__(self, temperature: float): super(MinimumClassConfusionLoss, self).__init__() self.temperature = temperature def forward(self, logits: torch.Tensor) -> torch.Tensor: batch_size, num_classes = logits.shape predictions = F.softmax(logits / self.temperature, dim=1) # batch_size x num_classes entropy_weight = entropy(predictions).detach() entropy_weight = 1 + torch.exp(-entropy_weight) entropy_weight = (batch_size * entropy_weight / torch.sum(entropy_weight)).unsqueeze(dim=1) # batch_size x 1 class_confusion_matrix = torch.mm((predictions * entropy_weight).transpose(1, 0), predictions) # num_classes x num_classes class_confusion_matrix = class_confusion_matrix / torch.sum(class_confusion_matrix, dim=1) mcc_loss = (torch.sum(class_confusion_matrix) - torch.trace(class_confusion_matrix)) / num_classes return mcc_loss class ImageClassifier(ClassifierBase): def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): bottleneck = nn.Sequential( 
def set_requires_grad(net, requires_grad=False):
    """Set ``requires_grad`` on all parameters of ``net`` to avoid unnecessary computations."""
    for param in net.parameters():
        param.requires_grad = requires_grad


class EMATeacher(object):
    r"""Exponential moving average model from `Mean teachers are better role models:
    Weight-averaged consistency targets improve semi-supervised deep learning results (NIPS 2017)
    <https://arxiv.org/abs/1703.01780>`_.

    With :math:`\theta_t'` the teacher parameters and :math:`\theta_t` the student parameters
    at training step t, and decay factor :math:`\alpha`, the teacher is updated as

    .. math::
        \theta_t' = \alpha \theta_{t-1}' + (1 - \alpha)\theta_t

    Args:
        model (torch.nn.Module): the student model
        alpha (float): decay factor for EMA

    Inputs:
        x (tensor): input tensor, forwarded through the teacher

    Examples::
        >>> teacher = EMATeacher(classifier, 0.9)
        >>> y_teacher = teacher(x)
        >>> teacher.update()  # after each optimizer step on the student
    """

    def __init__(self, model, alpha):
        self.model = model
        self.alpha = alpha
        # the teacher starts as an exact copy of the student and is never trained directly
        self.teacher = copy.deepcopy(model)
        set_requires_grad(self.teacher, False)

    def set_alpha(self, alpha: float):
        assert alpha >= 0
        self.alpha = alpha

    @torch.no_grad()
    def update(self):
        """EMA-update the teacher parameters from the student parameters.

        Operates in place on ``.data`` under ``torch.no_grad()``. The previous version
        multiplied the live parameter tensors (``alpha * teacher_param + (1 - alpha) * param``),
        which recorded a throwaway autograd graph on every step because the student parameters
        require grad; the numeric result is unchanged.
        """
        for teacher_param, param in zip(self.teacher.parameters(), self.model.parameters()):
            teacher_param.data.mul_(self.alpha).add_(param.data, alpha=1 - self.alpha)

    def __call__(self, x: torch.Tensor):
        return self.teacher(x)

    def train(self, mode: Optional[bool] = True):
        self.teacher.train(mode)

    def eval(self):
        self.train(False)

    def state_dict(self):
        return self.teacher.state_dict()

    def load_state_dict(self, state_dict):
        self.teacher.load_state_dict(state_dict)

    @property
    def module(self):
        # NOTE(review): only valid when the wrapped model itself exposes ``.module``
        # (e.g. nn.DataParallel) -- confirm at call sites.
        return self.teacher.module


def update_bn(model, ema_model):
    """Replace batch normalization statistics of the teacher model with those of the student.

    Matches modules by name and relies on the convention that BN modules carry ``'bn'`` in
    their name -- BN layers named otherwise are silently skipped.
    """
    for m2, m1 in zip(ema_model.named_modules(), model.named_modules()):
        if ('bn' in m2[0]) and ('bn' in m1[0]):
            bn2, bn1 = m2[1].state_dict(), m1[1].state_dict()
            bn2['running_mean'].data.copy_(bn1['running_mean'].data)
            bn2['running_var'].data.copy_(bn1['running_var'].data)
            bn2['num_batches_tracked'].data.copy_(bn1['num_batches_tracked'].data)
import nn as nn def sigmoid_warm_up(current_epoch, warm_up_epochs: int): """Exponential warm up function from `Temporal Ensembling for Semi-Supervised Learning (ICLR 2017) `_. """ assert warm_up_epochs >= 0 if warm_up_epochs == 0: return 1.0 else: current_epoch = np.clip(current_epoch, 0.0, warm_up_epochs) process = 1.0 - current_epoch / warm_up_epochs return float(np.exp(-5.0 * process * process)) class ConsistencyLoss(nn.Module): r""" Consistency loss between two predictions. Given distance measure :math:`D`, predictions :math:`p_1, p_2`, binary mask :math:`mask`, the consistency loss is .. math:: D(p_1, p_2) * mask Args: distance_measure (callable): Distance measure function. reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'`` Inputs: - p1: the first prediction - p2: the second prediction - mask: binary mask. Default: 1. (use all samples when calculating loss) Shape: - p1, p2: :math:`(N, C)` where C means the number of classes. - mask: :math:`(N, )` where N means mini-batch size. """ def __init__(self, distance_measure: Callable, reduction: Optional[str] = 'mean'): super(ConsistencyLoss, self).__init__() self.distance_measure = distance_measure self.reduction = reduction def forward(self, p1: torch.Tensor, p2: torch.Tensor, mask=1.): cons_loss = self.distance_measure(p1, p2) cons_loss = cons_loss * mask if self.reduction == 'mean': return cons_loss.mean() elif self.reduction == 'sum': return cons_loss.sum() else: return cons_loss class L2ConsistencyLoss(ConsistencyLoss): r""" L2 consistency loss. Given two predictions :math:`p_1, p_2` and binary mask :math:`mask`, the L2 consistency loss is .. 
math:: \text{MSELoss}(p_1, p_2) * mask """ def __init__(self, reduction: Optional[str] = 'mean'): def l2_distance(p1: torch.Tensor, p2: torch.Tensor): return ((p1 - p2) ** 2).sum(dim=1) super(L2ConsistencyLoss, self).__init__(l2_distance, reduction) ================================================ FILE: tllib/self_training/pseudo_label.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch.nn as nn import torch.nn.functional as F class ConfidenceBasedSelfTrainingLoss(nn.Module): """ Self training loss that adopts confidence threshold to select reliable pseudo labels from `Pseudo-Label : The Simple and Efficient Semi-Supervised Learning Method for Deep Neural Networks (ICML 2013) `_. Args: threshold (float): Confidence threshold. Inputs: - y: unnormalized classifier predictions. - y_target: unnormalized classifier predictions which will used for generating pseudo labels. Returns: A tuple, including - self_training_loss: self training loss with pseudo labels. - mask: binary mask that indicates which samples are retained (whose confidence is above the threshold). - pseudo_labels: generated pseudo labels. Shape: - y, y_target: :math:`(minibatch, C)` where C means the number of classes. - self_training_loss: scalar. - mask, pseudo_labels :math:`(minibatch, )`. 
""" def __init__(self, threshold: float): super(ConfidenceBasedSelfTrainingLoss, self).__init__() self.threshold = threshold def forward(self, y, y_target): confidence, pseudo_labels = F.softmax(y_target.detach(), dim=1).max(dim=1) mask = (confidence > self.threshold).float() self_training_loss = (F.cross_entropy(y, pseudo_labels, reduction='none') * mask).mean() return self_training_loss, mask, pseudo_labels ================================================ FILE: tllib/self_training/self_ensemble.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F from tllib.modules.classifier import Classifier as ClassifierBase class ClassBalanceLoss(nn.Module): r""" Class balance loss that penalises the network for making predictions that exhibit large class imbalance. Given predictions :math:`p` with dimension :math:`(N, C)`, we first calculate the mini-batch mean per-class probability :math:`p_{mean}` with dimension :math:`(C, )`, where .. math:: p_{mean}^j = \frac{1}{N} \sum_{i=1}^N p_i^j Then we calculate binary cross entropy loss between :math:`p_{mean}` and uniform probability vector :math:`u` with the same dimension where :math:`u^j` = :math:`\frac{1}{C}` .. math:: loss = \text{BCELoss}(p_{mean}, u) Args: num_classes (int): Number of classes Inputs: - p (tensor): predictions from classifier Shape: - p: :math:`(N, C)` where C means the number of classes. 
""" def __init__(self, num_classes): super(ClassBalanceLoss, self).__init__() self.uniform_distribution = torch.ones(num_classes) / num_classes def forward(self, p: torch.Tensor): return F.binary_cross_entropy(p.mean(dim=0), self.uniform_distribution.to(p.device)) class ImageClassifier(ClassifierBase): def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim: Optional[int] = 256, **kwargs): bottleneck = nn.Sequential( # nn.AdaptiveAvgPool2d(output_size=(1, 1)), # nn.Flatten(), nn.Linear(backbone.out_features, bottleneck_dim), nn.BatchNorm1d(bottleneck_dim), nn.ReLU() ) super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs) ================================================ FILE: tllib/self_training/self_tuning.py ================================================ """ Adapted from https://github.com/thuml/Self-Tuning/tree/master @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch import torch.nn as nn from torch.nn.functional import normalize from tllib.modules.classifier import Classifier as ClassifierBase class Classifier(ClassifierBase): """Classifier class for Self-Tuning. Args: backbone (torch.nn.Module): Any backbone to extract 2-d features from data num_classes (int): Number of classes. projection_dim (int, optional): Dimension of the projector head. Default: 128 finetune (bool): Whether finetune the classifier or train from scratch. 
Default: True Inputs: - x (tensor): input data fed to `backbone` Outputs: In the training mode, - h: projections - y: classifier's predictions In the eval mode, - y: classifier's predictions Shape: - Inputs: (minibatch, *) where * means, any number of additional dimensions - y: (minibatch, `num_classes`) - h: (minibatch, `projection_dim`) """ def __init__(self, backbone: nn.Module, num_classes: int, projection_dim=1024, bottleneck_dim=1024, finetune=True, pool_layer=None): bottleneck = nn.Sequential( nn.Linear(backbone.out_features, bottleneck_dim), nn.BatchNorm1d(bottleneck_dim), nn.ReLU(), nn.Dropout(0.5) ) bottleneck[0].weight.data.normal_(0, 0.005) bottleneck[0].bias.data.fill_(0.1) head = nn.Linear(1024, num_classes) super(Classifier, self).__init__(backbone, num_classes=num_classes, head=head, finetune=finetune, pool_layer=pool_layer, bottleneck=bottleneck, bottleneck_dim=bottleneck_dim) self.projector = nn.Linear(1024, projection_dim) self.projection_dim = projection_dim def forward(self, x: torch.Tensor): f = self.pool_layer(self.backbone(x)) f = self.bottleneck(f) # projections h = self.projector(f) h = normalize(h, dim=1) # predictions predictions = self.head(f) if self.training: return h, predictions else: return predictions def get_parameters(self, base_lr=1.0): params = [ {"params": self.backbone.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, {"params": self.head.parameters(), "lr": 1.0 * base_lr}, {"params": self.projector.parameters(), "lr": 0.1 * base_lr if self.finetune else 1.0 * base_lr}, ] return params class SelfTuning(nn.Module): r"""Self-Tuning module in `Self-Tuning for Data-Efficient Deep Learning (self-tuning, ICML 2021) `_. Args: encoder_q (Classifier): Query encoder. encoder_k (Classifier): Key encoder. num_classes (int): Number of classes K (int): Queue size. Default: 32 m (float): Momentum coefficient. Default: 0.999 T (float): Temperature. 
class SelfTuning(nn.Module):
    r"""Self-Tuning module in `Self-Tuning for Data-Efficient Deep Learning (self-tuning,
    ICML 2021) <http://ise.thss.tsinghua.edu.cn/~mlong/doc/Self-Tuning-icml21.pdf>`_.

    Maintains one feature queue per class and scores every query against its own-class queue
    (positives) and all other queues (negatives) to form pseudo group contrast (PGC) logits.

    Args:
        encoder_q (Classifier): Query encoder.
        encoder_k (Classifier): Key encoder, momentum-updated from ``encoder_q``.
        num_classes (int): Number of classes
        K (int): Queue size. Default: 32
        m (float): Momentum coefficient. Default: 0.999
        T (float): Temperature. Default: 0.07

    Inputs:
        - im_q (tensor): input data fed to `encoder_q`
        - im_k (tensor): input data fed to `encoder_k`
        - labels (tensor): classification labels of input data

    Outputs: pgc_logits, pgc_labels, y_q
        - pgc_logits: projector's predictions on both positive and negative samples
        - pgc_labels: contrastive labels
        - y_q: query classifier's predictions

    Shape:
        - im_q, im_k: (minibatch, *) where * means, any number of additional dimensions
        - labels: (minibatch, )
        - y_q: (minibatch, `num_classes`)
        - pgc_logits: (minibatch, 1 + `num_classes` :math:`\times` `K`, `projection_dim`)
        - pgc_labels: (minibatch, 1 + `num_classes` :math:`\times` `K`)
    """

    def __init__(self, encoder_q, encoder_k, num_classes, K=32, m=0.999, T=0.07):
        super(SelfTuning, self).__init__()
        self.K = K
        self.m = m
        self.T = T
        self.num_classes = num_classes
        self.encoder_q = encoder_q
        self.encoder_k = encoder_k
        # the key encoder starts as a copy of the query encoder and receives no gradients
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False
        # one queue of K normalized features per class, stored column-wise
        self.register_buffer("queue_list", torch.randn(encoder_q.projection_dim, K * self.num_classes))
        self.queue_list = normalize(self.queue_list, dim=0)
        self.register_buffer("queue_ptr", torch.zeros(self.num_classes, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Momentum update of the key encoder."""
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, h, label):
        """Write ``h`` into class ``label``'s queue at its current pointer, then advance it."""
        batch_size = h.shape[0]
        ptr = int(self.queue_ptr[label])
        real_ptr = ptr + label * self.K
        # replace the keys at ptr (dequeue and enqueue)
        self.queue_list[:, real_ptr:real_ptr + batch_size] = h.T
        self.queue_ptr[label] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, labels):
        batch_size = im_q.size(0)
        device = im_q.device

        # queries: h_q (N x projection_dim)
        h_q, y_q = self.encoder_q(im_q)

        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder
            h_k, _ = self.encoder_k(im_k)  # keys: h_k (N x projection_dim)

        # positive logits against the momentum key of the same image: N x 1
        logits_pos = torch.einsum('nl,nl->n', [h_q, h_k]).unsqueeze(-1)

        queue_snapshot = self.queue_list.clone().detach()
        neg_logits = torch.Tensor([]).to(device)
        pos_logits = torch.Tensor([]).to(device)
        for i in range(batch_size):
            class_start = labels[i] * self.K
            class_end = (labels[i] + 1) * self.K
            negatives = torch.cat([queue_snapshot[:, 0:class_start], queue_snapshot[:, class_end:]], dim=1)
            positives = queue_snapshot[:, class_start:class_end]
            neg_logits = torch.cat((neg_logits, torch.einsum('nl,lk->nk', [h_q[i:i + 1], negatives])), dim=0)
            pos_logits = torch.cat((pos_logits, torch.einsum('nl,lk->nk', [h_q[i:i + 1], positives])), dim=0)
            self._dequeue_and_enqueue(h_k[i:i + 1], labels[i])

        # logits: 1 + queue_size + queue_size * (class_num - 1)
        pgc_logits = torch.cat([logits_pos, pos_logits, neg_logits], dim=1)
        pgc_logits = nn.LogSoftmax(dim=1)(pgc_logits / self.T)

        # soft labels: uniform mass over the 1 + K positive slots
        pgc_labels = torch.zeros([batch_size, 1 + self.K * self.num_classes]).to(device)
        pgc_labels[:, 0:self.K + 1].fill_(1.0 / (self.K + 1))
        return pgc_logits, pgc_labels, y_q
================================================ FILE: tllib/self_training/uda.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch.nn as nn import torch.nn.functional as F class StrongWeakConsistencyLoss(nn.Module): """ Consistency loss between strong and weak augmented samples from `Unsupervised Data Augmentation for Consistency Training (NIPS 2020) `_. Args: threshold (float): Confidence threshold. temperature (float): Temperature. Inputs: - y_strong: unnormalized classifier predictions on strong augmented samples. - y: unnormalized classifier predictions on weak augmented samples. Shape: - y, y_strong: :math:`(minibatch, C)` where C means the number of classes. - Output: scalar. """ def __init__(self, threshold: float, temperature: float): super(StrongWeakConsistencyLoss, self).__init__() self.threshold = threshold self.temperature = temperature def forward(self, y_strong, y): confidence, _ = F.softmax(y.detach(), dim=1).max(dim=1) mask = (confidence > self.threshold).float() log_prob = F.log_softmax(y_strong / self.temperature, dim=1) con_loss = (F.kl_div(log_prob, F.softmax(y.detach(), dim=1), reduction='none').sum(dim=1)) con_loss = (con_loss * mask).sum() / max(mask.sum(), 1) return con_loss ================================================ FILE: tllib/translation/__init__.py ================================================ ================================================ FILE: tllib/translation/cycada.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch.nn as nn from torch import Tensor class SemanticConsistency(nn.Module): """ Semantic consistency loss is introduced by `CyCADA: Cycle-Consistent Adversarial Domain Adaptation (ICML 2018) `_ This helps to prevent label flipping during image translation. 
Args: ignore_index (tuple, optional): Specifies target values that are ignored and do not contribute to the input gradient. When :attr:`size_average` is ``True``, the loss is averaged over non-ignored targets. Default: (). reduction (string, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied, ``'mean'``: the weighted mean of the output is taken, ``'sum'``: the output will be summed. Note: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: ``'mean'`` Shape: - Input: :math:`(N, C)` where `C = number of classes`, or :math:`(N, C, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of `K`-dimensional loss. - Target: :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. - Output: scalar. If :attr:`reduction` is ``'none'``, then the same size as the target: :math:`(N)`, or :math:`(N, d_1, d_2, ..., d_K)` with :math:`K \geq 1` in the case of K-dimensional loss. Examples:: >>> loss = SemanticConsistency() >>> input = torch.randn(3, 5, requires_grad=True) >>> target = torch.empty(3, dtype=torch.long).random_(5) >>> output = loss(input, target) >>> output.backward() """ def __init__(self, ignore_index=(), reduction='mean'): super(SemanticConsistency, self).__init__() self.ignore_index = ignore_index self.loss = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction) def forward(self, input: Tensor, target: Tensor) -> Tensor: for class_idx in self.ignore_index: target[target == class_idx] = -1 return self.loss(input, target) ================================================ FILE: tllib/translation/cyclegan/__init__.py ================================================ from . import discriminator from . import generator from . import loss from . 
def _conv_uses_bias(norm_layer):
    """A conv layer needs its own bias only when the following normalization
    is InstanceNorm (which carries no affine parameters); BatchNorm's affine
    shift makes the conv bias redundant."""
    actual = norm_layer.func if isinstance(norm_layer, functools.partial) else norm_layer
    return actual == nn.InstanceNorm2d


class NLayerDiscriminator(nn.Module):
    """Construct a PatchGAN discriminator.

    Args:
        input_nc (int): the number of channels in input images.
        ndf (int): the number of filters in the last conv layer. Default: 64
        n_layers (int): the number of conv layers in the discriminator. Default: 3
        norm_layer (torch.nn.Module): normalization layer. Default: :class:`nn.BatchNorm2d`
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        super(NLayerDiscriminator, self).__init__()
        use_bias = _conv_uses_bias(norm_layer)
        kernel_size, padding = 4, 1

        # First stage: no normalization, plain LeakyReLU.
        layers = [nn.Conv2d(input_nc, ndf, kernel_size, 2, padding),
                  nn.LeakyReLU(0.2, True)]

        # Gradually widen channels (capped at 8x ndf) while halving resolution.
        prev_mult = 1
        for n in range(1, n_layers):
            cur_mult = min(2 ** n, 8)
            layers += [
                nn.Conv2d(ndf * prev_mult, ndf * cur_mult, kernel_size, 2, padding, bias=use_bias),
                norm_layer(ndf * cur_mult),
                nn.LeakyReLU(0.2, True),
            ]
            prev_mult = cur_mult

        # One more widening stage at stride 1.
        cur_mult = min(2 ** n_layers, 8)
        layers += [
            nn.Conv2d(ndf * prev_mult, ndf * cur_mult, kernel_size, 1, padding, bias=use_bias),
            norm_layer(ndf * cur_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Final projection to a 1-channel prediction map (one score per patch).
        layers.append(nn.Conv2d(ndf * cur_mult, 1, kernel_size, 1, padding))
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)


class PixelDiscriminator(nn.Module):
    """Construct a 1x1 PatchGAN discriminator (pixelGAN).

    Args:
        input_nc (int): the number of channels in input images.
        ndf (int): the number of filters in the last conv layer. Default: 64
        norm_layer (torch.nn.Module): normalization layer. Default: :class:`nn.BatchNorm2d`
    """

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
        super(PixelDiscriminator, self).__init__()
        use_bias = _conv_uses_bias(norm_layer)
        # All 1x1 convs: each output score depends on a single input pixel.
        self.net = nn.Sequential(
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias),
        )

    def forward(self, input):
        return self.net(input)


def patch(ndf, input_nc=3, norm='batch', n_layers=3, init_type='normal', init_gain=0.02):
    """
    PatchGAN classifier described in the original pix2pix paper.
    It can classify whether 70x70 overlapping patches are real or fake.
    Such a patch-level discriminator architecture has fewer parameters
    than a full-image discriminator and can work on arbitrarily-sized images
    in a fully convolutional fashion.

    Args:
        ndf (int): the number of filters in the first conv layer
        input_nc (int): the number of channels in input images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        n_layers (int): the number of conv layers in the discriminator. Default: 3
        init_type (str): the name of the initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    net = NLayerDiscriminator(input_nc, ndf, n_layers=n_layers,
                              norm_layer=get_norm_layer(norm_type=norm))
    init_weights(net, init_type, init_gain=init_gain)
    return net


def pixel(ndf, input_nc=3, norm='batch', init_type='normal', init_gain=0.02):
    """
    1x1 PixelGAN discriminator can classify whether a pixel is real or not.
    It encourages greater color diversity but has no effect on spatial statistics.

    Args:
        ndf (int): the number of filters in the first conv layer
        input_nc (int): the number of channels in input images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        init_type (str): the name of the initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    net = PixelDiscriminator(input_nc, ndf, norm_layer=get_norm_layer(norm_type=norm))
    init_weights(net, init_type, init_gain=init_gain)
    return net
class ResnetBlock(nn.Module):
    """Residual block: two (pad, 3x3 conv, norm) stages with a ReLU in between,
    added to the input via a skip connection (https://arxiv.org/pdf/1512.03385.pdf)."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct the convolutional branch of the block.

        Args:
            dim (int): the number of channels in the conv layer.
            padding_type (str): the name of padding layer: reflect | replicate | zero
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers.
            use_bias (bool): if the conv layer uses bias or not

        Returns an ``nn.Sequential`` of conv, normalization and ReLU layers.
        """
        def padded_conv():
            # One padding + 3x3 conv stage. 'reflect'/'replicate' use an
            # explicit padding layer; 'zero' folds padding into the conv.
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        branch = padded_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            branch.append(nn.Dropout(0.5))
        branch += padded_conv() + [norm_layer(dim)]
        return nn.Sequential(*branch)

    def forward(self, x):
        """Forward function (with skip connections)."""
        return x + self.conv_block(x)


class ResnetGenerator(nn.Module):
    """Resnet-based generator: a stack of Resnet blocks between a few
    downsampling/upsampling operations (idea adapted from Justin Johnson's
    neural style transfer project, https://github.com/jcjohnson/fast-neural-style).
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator.

        Args:
            input_nc (int): the number of channels in input images
            output_nc (int): the number of channels in output images
            ngf (int): the number of filters in the last conv layer
            norm_layer (torch.nn.Module): normalization layer
            use_dropout (bool): if use dropout layers
            n_blocks (int): the number of ResNet blocks
            padding_type (str): the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert n_blocks >= 0
        super(ResnetGenerator, self).__init__()
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        # Stem: reflection-padded 7x7 conv keeping the spatial size.
        layers = [nn.ReflectionPad2d(3),
                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                  norm_layer(ngf),
                  nn.ReLU(True)]

        n_downsampling = 2
        channels = ngf
        for _ in range(n_downsampling):  # halve resolution, double channels
            layers += [nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                       norm_layer(channels * 2),
                       nn.ReLU(True)]
            channels *= 2

        for _ in range(n_blocks):  # residual core at the bottleneck resolution
            layers.append(ResnetBlock(channels, padding_type=padding_type, norm_layer=norm_layer,
                                      use_dropout=use_dropout, use_bias=use_bias))

        for _ in range(n_downsampling):  # mirror of the downsampling path
            layers += [nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2,
                                          padding=1, output_padding=1, bias=use_bias),
                       norm_layer(channels // 2),
                       nn.ReLU(True)]
            channels //= 2

        # Head: 7x7 conv back to image channels, Tanh maps into [-1, 1].
        layers += [nn.ReflectionPad2d(3),
                   nn.Conv2d(channels, output_nc, kernel_size=7, padding=0),
                   nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Standard forward."""
        return self.model(input)
"""Create a Unet-based generator""" def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): """Construct a Unet generator Args: input_nc (int): the number of channels in input images output_nc (int): the number of channels in output images num_downs (int): the number of downsamplings in UNet. For example, # if |num_downs| == 7, image of size 128x128 will become of size 1x1 # at the bottleneck ngf (int): the number of filters in the last conv layer norm_layer(torch.nn.Module): normalization layer We construct the U-Net from the innermost layer to the outermost layer. It is a recursive process. """ super(UnetGenerator, self).__init__() # construct unet structure unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) # gradually reduce the number of filters from ngf * 8 to ngf unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer def forward(self, input): """Standard forward""" return self.model(input) class UnetSkipConnectionBlock(nn.Module): """Defines the Unet submodule with skip connection. 
X -------------------identity---------------------- |-- downsampling -- |submodule| -- upsampling --| """ def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): """Construct a Unet submodule with skip connections. Args: outer_nc (int): the number of filters in the outer conv layer inner_nc (int): the number of filters in the inner conv layer input_nc (int): the number of channels in input images/features submodule (UnetSkipConnectionBlock): previously defined submodules outermost (bool): if this module is the outermost module innermost (bool): if this module is the innermost module norm_layer (torch.nn.Module): normalization layer use_dropout (bool): if use dropout layers. """ super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d if input_nc is None: input_nc = outer_nc downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) down = [downconv] up = [uprelu, upconv, nn.Tanh()] model = down + [submodule] + up elif innermost: upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv] up = [uprelu, upconv, upnorm] model = down + up else: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv, downnorm] up = [uprelu, upconv, upnorm] if use_dropout: model = down + [submodule] + up + [nn.Dropout(0.5)] else: model = down + [submodule] + up self.model = nn.Sequential(*model) def forward(self, x): if self.outermost: return 
self.model(x) else: # add skip connections return torch.cat([x, self.model(x)], 1) def resnet_9(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02): """ Resnet-based generator with 9 Resnet blocks. Args: ngf (int): the number of filters in the last conv layer input_nc (int): the number of channels in input images. Default: 3 output_nc (int): the number of channels in output images. Default: 3 norm (str): the type of normalization layers used in the network. Default: 'batch' use_dropout (bool): whether use dropout. Default: False init_type (str): the name of the initialization method. Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal' init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02 """ norm_layer = get_norm_layer(norm_type=norm) net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) init_weights(net, init_type, init_gain) return net def resnet_6(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02): """ Resnet-based generator with 6 Resnet blocks. Args: ngf (int): the number of filters in the last conv layer input_nc (int): the number of channels in input images. Default: 3 output_nc (int): the number of channels in output images. Default: 3 norm (str): the type of normalization layers used in the network. Default: 'batch' use_dropout (bool): whether use dropout. Default: False init_type (str): the name of the initialization method. Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal' init_gain (float): scaling factor for normal, xavier and orthogonal. 
def _build_unet(num_downs, ngf, input_nc, output_nc, norm, use_dropout, init_type, init_gain):
    """Shared builder for the unet_* factories: construct the generator for the
    given depth, initialize its weights, and return it."""
    norm_layer = get_norm_layer(norm_type=norm)
    net = UnetGenerator(input_nc, output_nc, num_downs, ngf, norm_layer=norm_layer,
                        use_dropout=use_dropout)
    init_weights(net, init_type, init_gain)
    return net


def unet_256(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 256x256 input images.
    The size of the input image should be a multiple of 256.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    # 8 downsamplings: 256 -> 1 at the bottleneck.
    return _build_unet(8, ngf, input_nc, output_nc, norm, use_dropout, init_type, init_gain)


def unet_128(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
             init_type='normal', init_gain=0.02):
    """
    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 128x128 input images.
    The size of the input image should be a multiple of 128.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    # 7 downsamplings: 128 -> 1 at the bottleneck.
    return _build_unet(7, ngf, input_nc, output_nc, norm, use_dropout, init_type, init_gain)


def unet_32(ngf, input_nc=3, output_nc=3, norm='batch', use_dropout=False,
            init_type='normal', init_gain=0.02):
    """
    `U-Net <https://arxiv.org/abs/1505.04597>`_ generator for 32x32 input images.

    Args:
        ngf (int): the number of filters in the last conv layer
        input_nc (int): the number of channels in input images. Default: 3
        output_nc (int): the number of channels in output images. Default: 3
        norm (str): the type of normalization layers used in the network. Default: 'batch'
        use_dropout (bool): whether use dropout. Default: False
        init_type (str): the name of the initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``. Default: 'normal'
        init_gain (float): scaling factor for normal, xavier and orthogonal. Default: 0.02
    """
    # 5 downsamplings: 32 -> 1 at the bottleneck.
    return _build_unet(5, ngf, input_nc, output_nc, norm, use_dropout, init_type, init_gain)
class LeastSquaresGenerativeAdversarialLoss(nn.Module):
    """
    Loss for `Least Squares Generative Adversarial Network (LSGAN) <https://arxiv.org/abs/1611.04076>`_

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super(LeastSquaresGenerativeAdversarialLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction=reduction)

    def forward(self, prediction, real=True):
        # Target is all-ones for real samples, all-zeros for fake ones.
        label = torch.ones_like(prediction) if real else torch.zeros_like(prediction)
        return self.mse_loss(prediction, label)


class VanillaGenerativeAdversarialLoss(nn.Module):
    """
    Loss for `Vanilla Generative Adversarial Network <https://arxiv.org/abs/1406.2661>`_

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super(VanillaGenerativeAdversarialLoss, self).__init__()
        # BCEWithLogits combines sigmoid + BCE, hence the no-sigmoid warning.
        self.bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, prediction, real=True):
        label = torch.ones_like(prediction) if real else torch.zeros_like(prediction)
        return self.bce_loss(prediction, label)


class WassersteinGenerativeAdversarialLoss(nn.Module):
    """
    Loss for `Wasserstein Generative Adversarial Network <https://arxiv.org/abs/1701.07875>`_

    Args:
        reduction (str, optional): kept only for interface consistency with the
          other GAN losses; the Wasserstein estimate is always a mean. Default: ``'mean'``

    Inputs:
        - prediction (tensor): unnormalized discriminator predictions
        - real (bool): if the ground truth label is for real images or fake images. Default: true

    .. warning::
        Do not use sigmoid as the last layer of Discriminator.
    """

    def __init__(self, reduction='mean'):
        super(WassersteinGenerativeAdversarialLoss, self).__init__()
        # NOTE: a previous revision constructed an nn.MSELoss here that was
        # never used; it has been removed (no parameters, so state_dicts are
        # unaffected).

    def forward(self, prediction, real=True):
        # Critic score: maximize on real (minimize its negation), minimize on fake.
        return -prediction.mean() if real else prediction.mean()
Default: 'cpu' mean (tuple): the normalized mean for image std (tuple): the normalized std for image Input: - image (PIL.Image): raw image in shape H x W x C Output: raw image in shape H x W x 3 """ def __init__(self, generator, device=torch.device("cpu"), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)): super(Translation, self).__init__() self.generator = generator.to(device) self.device = device self.pre_process = T.Compose([ T.ToTensor(), T.Normalize(mean, std) ]) self.post_process = T.Compose([ Denormalize(mean, std), T.ToPILImage() ]) def forward(self, image): image = self.pre_process(image.copy()) # C x H x W image = image.to(self.device) generated_image = self.generator(image.unsqueeze(dim=0)).squeeze(dim=0).cpu() return self.post_process(generated_image) ================================================ FILE: tllib/translation/cyclegan/util.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch.nn as nn import functools import random import torch from torch.nn import init class Identity(nn.Module): def forward(self, x): return x def get_norm_layer(norm_type='instance'): """Return a normalization layer Parameters: norm_type (str) -- the name of the normalization layer: batch | instance | none For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. """ if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) elif norm_type == 'none': def norm_layer(x): return Identity() else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer def init_weights(net, init_type='normal', init_gain=0.02): """Initialize network weights. 
def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights in place.

    Args:
        net (torch.nn.Module): network to be initialized
        init_type (str): the name of an initialization method.
            Choices includes: ``normal`` | ``xavier`` | ``kaiming`` | ``orthogonal``
        init_gain (float): scaling factor for normal, xavier and orthogonal.

    'normal' is used in the original CycleGAN paper, but xavier and kaiming
    might work better for some applications.
    """
    def init_func(m):
        cls_name = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in cls_name or 'Linear' in cls_name):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in cls_name:
            # BatchNorm's weight is a vector, not a matrix, so only the
            # normal scheme applies regardless of init_type.
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)


class ImagePool:
    """An image buffer that stores previously generated images.

    This buffer enables updating discriminators using a history of generated
    images rather than only the ones produced by the latest generators.

    Args:
        pool_size (int): the size of image buffer; if pool_size = 0, no buffer
            will be created
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            # Lazily-filled pool state.
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch of images drawn from the pool.

        While the pool is filling, each incoming image is stored and returned
        as-is. Once full, each image is — with 50% probability — swapped for a
        randomly chosen stored image (the incoming one takes its slot);
        otherwise it is returned unchanged.

        Args:
            images (torch.Tensor): the latest generated images from the generator
        """
        if self.pool_size == 0:
            # Buffering disabled: pass the batch straight through.
            return images
        selected = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                # Pool still filling: store and return the current image.
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                selected.append(image)
            elif random.uniform(0, 1) > 0.5:
                # Return a stored image; the current one takes its slot.
                idx = random.randint(0, self.pool_size - 1)  # randint is inclusive
                stored = self.images[idx].clone()
                self.images[idx] = image
                selected.append(stored)
            else:
                # Otherwise return the current image untouched.
                selected.append(image)
        return torch.cat(selected, 0)


def set_requires_grad(net, requires_grad=False):
    """
    Set ``requires_grad`` on every parameter of ``net``; setting it to False
    avoids unnecessary gradient computation for frozen networks.
    """
    for param in net.parameters():
        param.requires_grad = requires_grad
def low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: Optional[int] = 1):
    """
    Args:
        amp_src (numpy.ndarray): amplitude component of the Fourier transform of source image
        amp_trg (numpy.ndarray): amplitude component of the Fourier transform of target image
        beta (int, optional): the size of the center region to be replaced. Default: 1

    Returns:
        amplitude component of the Fourier transform of the source image whose
        low-frequency component is replaced by that of the target image.
    """
    # Shift the zero-frequency component to the center of the spectrum.
    shifted_src = np.fft.fftshift(amp_src, axes=(-2, -1))
    shifted_trg = np.fft.fftshift(amp_trg, axes=(-2, -1))

    # The low-frequency region is the square whose horizontal/vertical
    # distance from the spectrum center does not exceed beta.
    _, h, w = shifted_src.shape
    center_h = int(np.floor(h / 2.0))
    center_w = int(np.floor(w / 2.0))
    rows = slice(center_h - beta, center_h + beta + 1)
    cols = slice(center_w - beta, center_w + beta + 1)

    # Replace the source low-frequency amplitude with the target's.
    shifted_src[:, rows, cols] = shifted_trg[:, rows, cols]
    return np.fft.ifftshift(shifted_src, axes=(-2, -1))


class FourierTransform(nn.Module):
    """
    Fourier Transform is introduced by `FDA: Fourier Domain Adaptation for Semantic
    Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_

    It replaces the low-frequency component of the amplitude of the source image
    with that of a randomly sampled target image, keeping the source phase:

    .. math::
        x^{s→t} = \mathcal{F}^{-1}([ M_{β}\circ\mathcal{F}^A(x^t)
            + (1-M_{β})\circ\mathcal{F}^A(x^s), \mathcal{F}^P(x^s) ])

    where :math:`\mathcal{F}^A`, :math:`\mathcal{F}^P` are the amplitude and phase
    components of the Fourier Transform :math:`\mathcal{F}` of an RGB image, and
    :math:`M_{β}` is a mask that is zero except for a center region of size :math:`β`.

    Args:
        image_list (sequence[str]): A sequence of image paths from the target domain.
        amplitude_dir (str): Directory caching the amplitude component of the target images.
        beta (int, optional): :math:`β`. Default: 1.
        rebuild (bool, optional): whether to rebuild the cached target amplitudes.

    Inputs:
        - image (PIL Image): image from the source domain, :math:`x^s`.

    .. note::
        The meaning of :math:`β` differs from the original paper: here the size of
        the replaced center region is independent of the image size. A recommended
        value is 1.

    .. note::
        Source and target images must have the same size; apply a Resize to the
        target size before this transform, and apply it before RandomResizedCrop
        and similar transformations.
    """
    # TODO add image examples when beta is different

    def __init__(self, image_list: Sequence[str], amplitude_dir: str,
                 beta: Optional[int] = 1, rebuild: Optional[bool] = False):
        super(FourierTransform, self).__init__()
        self.amplitude_dir = amplitude_dir
        # Target amplitudes are cached on disk; reuse unless asked to rebuild.
        if rebuild or not os.path.exists(amplitude_dir):
            os.makedirs(amplitude_dir, exist_ok=True)
            self.build_amplitude(image_list, amplitude_dir)
        self.beta = beta
        self.length = len(image_list)

    @staticmethod
    def build_amplitude(image_list, amplitude_dir):
        """Extract and save the amplitude spectrum of every target-domain image."""
        for i, image_name in enumerate(tqdm.tqdm(image_list)):
            pixels = np.asarray(Image.open(image_name).convert('RGB'), np.float32)
            pixels = pixels.transpose((2, 0, 1))
            amplitude = np.abs(np.fft.fft2(pixels, axes=(-2, -1)))
            np.save(os.path.join(amplitude_dir, "{}.npy".format(i)), amplitude)

    def forward(self, image):
        # Randomly sample a target image and load its cached amplitude.
        target_idx = random.randint(0, self.length - 1)
        amp_trg = np.load(os.path.join(self.amplitude_dir, "{}.npy".format(target_idx)))

        pixels = np.asarray(image, np.float32).transpose((2, 0, 1))

        # FFT of the source image: amplitude and phase.
        fft_src = np.fft.fft2(pixels, axes=(-2, -1))
        amp_src, pha_src = np.abs(fft_src), np.angle(fft_src)

        # Swap in the target's low-frequency amplitude, keep the source phase.
        mutated_amp = low_freq_mutate(amp_src, amp_trg, beta=self.beta)
        mutated_fft = mutated_amp * np.exp(1j * pha_src)

        # Back to image space; clip to valid pixel range.
        src_in_trg = np.real(np.fft.ifft2(mutated_fft, axes=(-2, -1)))
        src_in_trg = src_in_trg.transpose((1, 2, 0))
        return Image.fromarray(src_in_trg.clip(min=0, max=255).astype('uint8')).convert('RGB')
class ContrastiveLoss(torch.nn.Module):
    r"""Contrastive loss from `Dimensionality Reduction by Learning an Invariant Mapping
    (CVPR 2006) <http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf>`_.

    Given output features :math:`f_1, f_2` with pairwise euclidean distance :math:`D`,
    ground truth labels :math:`Y` and a pre-defined margin :math:`m`, the loss is

    .. math::
        (1 - Y)\frac{1}{2}D^2 + (Y)\frac{1}{2}\{\text{max}(0, m-D)^2\}

    Args:
        margin (float, optional): margin for contrastive loss. Default: 2.0

    Inputs:
        - output1 (tensor): feature representations of the first set of samples (:math:`f_1` here).
        - output2 (tensor): feature representations of the second set of samples (:math:`f_2` here).
        - label (tensor): labels (:math:`Y` here).

    Shape:
        - output1, output2: :math:`(minibatch, F)` where F means the dimension of input features.
        - label: :math:`(minibatch, )`
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = F.pairwise_distance(output1, output2)
        # similar pairs (label 0) are pulled together, dissimilar pairs (label 1)
        # are pushed apart up to the margin
        pull_term = (1 - label) * dist.pow(2)
        push_term = label * (self.margin - dist).clamp(min=0.0).pow(2)
        return (pull_term + push_term).mean()
    def __init__(self, nsf=64):
        super(SiameseNetwork, self).__init__()
        # three downsampling stages; each stage is a stride-2 conv followed by a
        # stride-2 max-pool, so every stage shrinks the spatial size by 4x
        self.conv = nn.Sequential(
            nn.Conv2d(3, nsf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBlock(nsf, nsf * 2),
            ConvBlock(nsf * 2, nsf * 4),
        )
        self.flatten = nn.Flatten()
        # NOTE(review): fc1 assumes the flattened conv output has exactly 2048
        # elements, i.e. a fixed input resolution (e.g. 256x128 with nsf=64) —
        # confirm against callers
        self.fc1 = nn.Linear(2048, nsf * 2, bias=False)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(nsf * 2, nsf, bias=False)

    def forward(self, x):
        # conv features -> 2-layer MLP head
        x = self.flatten(self.conv(x))
        x = self.fc1(x)
        x = self.leaky_relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        # project onto the unit sphere so features are directly comparable
        x = F.normalize(x)
        return x
class ANet(nn.Module):
    # minimal domain discriminator: one linear layer + sigmoid, outputting the
    # probability that a feature comes from the source domain
    def __init__(self, in_feature):
        super(ANet, self).__init__()
        self.layer = nn.Linear(in_feature, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.layer(x)
        x = self.sigmoid(x)
        return x


def calculate(source_feature: torch.Tensor, target_feature: torch.Tensor, device,
              progress=True, training_epochs=10):
    """
    Calculate the :math:`\mathcal{A}`-distance, which is a measure for distribution discrepancy.

    The definition is :math:`dist_\mathcal{A} = 2 (1-2\epsilon)`, where :math:`\epsilon` is the
    test error of a classifier trained to discriminate the source from the target.

    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        device (torch.device)
        progress (bool): if True, displays the progress of training A-Net
        training_epochs (int): the number of epochs when training the classifier

    Returns:
        :math:`\mathcal{A}`-distance
    """
    # source samples are labelled 1, target samples 0
    source_label = torch.ones((source_feature.shape[0], 1))
    target_label = torch.zeros((target_feature.shape[0], 1))
    feature = torch.cat([source_feature, target_feature], dim=0)
    label = torch.cat([source_label, target_label], dim=0)

    # 80/20 random train/validation split over the pooled features
    dataset = TensorDataset(feature, label)
    length = len(dataset)
    train_size = int(0.8 * length)
    val_size = length - train_size
    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])
    train_loader = DataLoader(train_set, batch_size=2, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

    anet = ANet(feature.shape[1]).to(device)
    optimizer = SGD(anet.parameters(), lr=0.01)
    a_distance = 2.0
    for epoch in range(training_epochs):
        anet.train()
        # (the loop variable shadows the outer `label`; harmless since the outer
        # value is no longer needed here)
        for (x, label) in train_loader:
            x = x.to(device)
            label = label.to(device)
            anet.zero_grad()
            y = anet(x)
            loss = F.binary_cross_entropy(y, label)
            loss.backward()
            optimizer.step()

        # validation error of the domain classifier -> A-distance estimate
        anet.eval()
        meter = AverageMeter("accuracy", ":4.2f")
        with torch.no_grad():
            for (x, label) in val_loader:
                x = x.to(device)
                label = label.to(device)
                y = anet(x)
                acc = binary_accuracy(y, label)
                meter.update(acc, x.shape[0])
        error = 1 - meter.avg / 100
        a_distance = 2 * (1 - 2 * error)
        if progress:
            print("epoch {} accuracy: {} A-dist: {}".format(epoch, meter.avg, a_distance))

    return a_distance
def visualize(source_feature: torch.Tensor, target_feature: torch.Tensor,
              filename: str, source_color='r', target_color='b'):
    """
    Project source- and target-domain features into 2-D with t-SNE and save a
    scatter plot.

    Args:
        source_feature (tensor): features from source domain in shape :math:`(minibatch, F)`
        target_feature (tensor): features from target domain in shape :math:`(minibatch, F)`
        filename (str): the file name to save t-SNE
        source_color (str): the color of the source features. Default: 'r'
        target_color (str): the color of the target features. Default: 'b'
    """
    src = source_feature.numpy()
    tgt = target_feature.numpy()
    all_features = np.concatenate([src, tgt], axis=0)

    # embed every feature into 2-D (fixed seed keeps the plot reproducible)
    embedded = TSNE(n_components=2, random_state=33).fit_transform(all_features)

    # domain labels: 1 for source samples, 0 for target samples
    domain_labels = np.concatenate((np.ones(len(src)), np.zeros(len(tgt))))

    # draw a frameless scatter plot colored by domain
    fig, ax = plt.subplots(figsize=(10, 10))
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    plt.scatter(embedded[:, 0], embedded[:, 1], c=domain_labels,
                cmap=col.ListedColormap([target_color, source_color]), s=20)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(filename)
class ForeverDataIterator:
    r"""A data iterator that never raises ``StopIteration``: whenever the wrapped
    data loader is exhausted it is transparently restarted, so ``next()`` always
    yields a batch.

    Args:
        data_loader (DataLoader): the loader to cycle over forever.
        device: if not None, every batch is moved to this device via ``send_to_device``.
    """

    def __init__(self, data_loader: DataLoader, device=None):
        self.data_loader = data_loader
        self.iter = iter(self.data_loader)
        self.device = device

    def __next__(self):
        try:
            batch = next(self.iter)
        except StopIteration:
            # epoch finished: restart the loader and take its first batch
            self.iter = iter(self.data_loader)
            batch = next(self.iter)
        if self.device is not None:
            batch = send_to_device(batch, self.device)
        return batch

    def __len__(self):
        return len(self.data_loader)
    def __init__(self, dataset, num_instances=4):
        # dataset: list of (image_path, person_id, camera_id) tuples
        # num_instances: K, the number of images sampled for every identity
        super(RandomMultipleGallerySampler, self).__init__(dataset)
        self.dataset = dataset
        self.num_instances = num_instances
        # dataset index -> person id
        self.idx_to_pid = {}
        # person id -> camera ids of all images of that person
        self.cid_list_per_pid = {}
        # person id -> dataset indices of all images of that person
        self.idx_list_per_pid = {}

        for idx, (_, pid, cid) in enumerate(dataset):
            if pid not in self.cid_list_per_pid:
                self.cid_list_per_pid[pid] = []
                self.idx_list_per_pid[pid] = []

            self.idx_to_pid[idx] = pid
            self.cid_list_per_pid[pid].append(cid)
            self.idx_list_per_pid[pid].append(idx)

        self.pid_list = list(self.idx_list_per_pid.keys())
        # one pass over __iter__ yields num_instances indices per identity
        self.num_samples = len(self.pid_list)

    def __len__(self):
        return self.num_samples * self.num_instances
def concatenate(tensors):
    """Concatenate multiple batches into one batch.

    Args:
        tensors (sequence): a non-empty sequence of batches. Each batch may be a
            :class:`torch.Tensor`, a list of batches or a dict of batches, but all
            batches must share the same (possibly nested) structure.

    Returns:
        A single batch with the same structure as the inputs, in which every leaf
        tensor is the concatenation (along dim 0) of the corresponding leaves.

    Raises:
        TypeError: if the leaf elements are not tensors, lists or dicts (previously
            this case fell through and silently returned ``None``).
    """
    first = tensors[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(tensors, dim=0)
    # isinstance() against typing.List / typing.Dict is deprecated; the builtin
    # list / dict classes give the identical check
    elif isinstance(first, list):
        return [concatenate([t[i] for t in tensors]) for i in range(len(first))]
    elif isinstance(first, dict):
        return {k: concatenate([t[k] for t in tensors]) for k in first.keys()}
    else:
        raise TypeError("Unsupported batch type: {}".format(type(first)))
    def __init__(self, filename, stream=sys.stdout):
        # keep the original stream so output is still shown on the console
        self.terminal = stream
        # append mode: successive runs accumulate in the same file
        # NOTE(review): no explicit encoding — the platform default is used
        self.log = open(filename, 'a')

    def write(self, message):
        # mirror every message to both the console and the log file
        self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def close(self):
        # NOTE(review): this also closes the wrapped stream (often sys.stdout);
        # confirm no output is expected after close()
        self.terminal.close()
        self.log.close()
class AverageMeter(object):
    r"""Tracks the most recent value and the running average of a series.

    Examples::

        >>> # Initialize a meter to record loss
        >>> losses = AverageMeter()
        >>> # Update meter after every minibatch update
        >>> losses.update(loss_value, batch_size)
    """

    def __init__(self, name: str, fmt: Optional[str] = ':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # clear current value, average, running sum and sample count
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # record the latest value and fold `n` samples of it into the average
        self.val = val
        self.sum += val * n
        self.count += n
        if self.count > 0:
            self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
    r"""
    Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output (tensor): Classification outputs, :math:`(N, C)` where `C = number of classes`
        target (tensor): :math:`(N)` where each value is :math:`0 \leq \text{targets}[i] \leq C-1`
        topk (sequence[int]): A list of top-N number.

    Returns:
        Top-N accuracies (N :math:`\in` topK).
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # (N, largest_k) indices of the highest-scoring classes, transposed so
        # row r holds the rank-r prediction for every sample
        _, top_classes = output.topk(largest_k, 1, True, True)
        hit = top_classes.t().eq(target[None])

        # a sample counts for top-k when the target appears within the first k rows
        return [hit[:k].flatten().sum(dtype=torch.float32) * (100.0 / num_samples)
                for k in topk]
    def reset(self):
        # NOTE(review): raises AttributeError when called before the first
        # update() (self.mat is still None) — confirm callers update first
        self.mat.zero_()

    def compute(self):
        """compute global accuracy, per-class accuracy and per-class IoU"""
        h = self.mat.float()
        # overall fraction of correctly classified samples
        acc_global = torch.diag(h).sum() / h.sum()
        # per-class recall: correct predictions / ground-truth count per class
        acc = torch.diag(h) / h.sum(1)
        # per-class intersection-over-union
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    # distributed reduction was left disabled; kept for reference
    # def reduce_from_all_processes(self):
    #     if not torch.distributed.is_available():
    #         return
    #     if not torch.distributed.is_initialized():
    #         return
    #     torch.distributed.barrier()
    #     torch.distributed.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)

    def format(self, classes: list):
        """Get the accuracy and IoU for each class in the table format"""
        acc_global, acc, iu = self.compute()
        table = prettytable.PrettyTable(["class", "acc", "iou"])
        for i, class_name, per_acc, per_iu in zip(range(len(classes)), classes,
                                                  (acc * 100).tolist(), (iu * 100).tolist()):
            table.add_row([class_name, per_acc, per_iu])
        return 'global correct: {:.1f}\nmean correct:{:.1f}\nmean IoU: {:.1f}\n{}'.format(
            acc_global.item() * 100, acc.mean().item() * 100, iu.mean().item() * 100, table.get_string())
def get_max_preds(batch_heatmaps):
    """Locate the peak of every joint heatmap.

    Args:
        batch_heatmaps (numpy.ndarray): score maps of shape
            ``(batch_size, num_joints, height, width)``.

    Returns:
        preds: ``(batch_size, num_joints, 2)`` float array of (x, y) peak
            locations; joints whose best score is not positive are zeroed out.
        maxvals: ``(batch_size, num_joints, 1)`` array of peak scores.
    """
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'

    batch_size, num_joints = batch_heatmaps.shape[0], batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]

    # flatten the spatial dims and take the arg-max / max per joint
    flat = batch_heatmaps.reshape((batch_size, num_joints, -1))
    flat_idx = np.argmax(flat, 2).reshape((batch_size, num_joints, 1))
    maxvals = np.amax(flat, 2).reshape((batch_size, num_joints, 1))

    # convert flat indices back to (x, y) coordinates
    preds = np.tile(flat_idx, (1, 1, 2)).astype(np.float32)
    preds[:, :, 0] = (preds[:, :, 0]) % width            # x = column
    preds[:, :, 1] = np.floor((preds[:, :, 1]) / width)  # y = row

    # suppress joints whose best score is not positive
    visible = np.tile(np.greater(maxvals, 0.0), (1, 1, 2)).astype(np.float32)
    preds *= visible
    return preds, maxvals
def unique_sample(ids_dict, num):
    """Randomly choose one instance for each person id, these instances will not be selected again.

    Args:
        ids_dict (dict): maps each person id to a list of candidate indices in ``[0, num)``.
        num (int): total number of instances, i.e. the length of the returned mask.

    Returns:
        numpy.ndarray: boolean mask of shape ``(num,)`` with exactly one ``True``
        entry per person id.
    """
    # `np.bool` was a deprecated alias removed in NumPy 1.24; the builtin `bool`
    # produces the identical boolean dtype
    mask = np.zeros(num, dtype=bool)
    for _, indices in ids_dict.items():
        i = np.random.choice(indices)
        mask[i] = True
    return mask
def mean_ap(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams):
    """Compute mean average precision (mAP)"""
    dist = dist_mat.cpu().numpy()
    num_queries = dist.shape[0]
    q_ids = np.asarray(query_ids)
    g_ids = np.asarray(gallery_ids)
    q_cams = np.asarray(query_cams)
    g_cams = np.asarray(gallery_cams)

    # rank the gallery by increasing distance for every query
    order = np.argsort(dist, axis=1)
    match_mat = (g_ids[order] == q_ids[:, np.newaxis])

    # average precision per query, skipping queries without any valid match
    ap_list = []
    for qi in range(num_queries):
        # drop gallery entries that share both id and camera with the query
        keep = ((g_ids[order[qi]] != q_ids[qi]) |
                (g_cams[order[qi]] != q_cams[qi]))
        labels = match_mat[qi, keep]
        scores = -dist[qi][order[qi]][keep]
        if np.any(labels):
            ap_list.append(average_precision_score(labels, scores))
    if len(ap_list) == 0:
        raise RuntimeError("No valid query")
    return np.mean(ap_list)
and gallery images `q_g_dist`, distance matrix between query and query images `q_q_dist` and distance matrix between gallery and gallery images `g_g_dist`. """ q_g_dist = q_g_dist.cpu().numpy() q_q_dist = q_q_dist.cpu().numpy() g_g_dist = g_g_dist.cpu().numpy() original_dist = np.concatenate( [np.concatenate([q_q_dist, q_g_dist], axis=1), np.concatenate([q_g_dist.T, g_g_dist], axis=1)], axis=0) original_dist = np.power(original_dist, 2).astype(np.float32) original_dist = np.transpose(1. * original_dist / np.max(original_dist, axis=0)) V = np.zeros_like(original_dist).astype(np.float32) initial_rank = np.argsort(original_dist).astype(np.int32) query_num = q_g_dist.shape[0] gallery_num = q_g_dist.shape[0] + q_g_dist.shape[1] all_num = gallery_num for i in range(all_num): # k-reciprocal neighbors forward_k_neigh_index = initial_rank[i, :k1 + 1] backward_k_neigh_index = initial_rank[forward_k_neigh_index, :k1 + 1] fi = np.where(backward_k_neigh_index == i)[0] k_reciprocal_index = forward_k_neigh_index[fi] k_reciprocal_expansion_index = k_reciprocal_index for j in range(len(k_reciprocal_index)): candidate = k_reciprocal_index[j] candidate_forward_k_neigh_index = initial_rank[candidate, :int(np.around(k1 / 2.)) + 1] candidate_backward_k_neigh_index = initial_rank[candidate_forward_k_neigh_index, :int(np.around(k1 / 2.)) + 1] fi_candidate = np.where(candidate_backward_k_neigh_index == candidate)[0] candidate_k_reciprocal_index = candidate_forward_k_neigh_index[fi_candidate] if len(np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)) > 2. / 3 * len( candidate_k_reciprocal_index): k_reciprocal_expansion_index = np.append(k_reciprocal_expansion_index, candidate_k_reciprocal_index) k_reciprocal_expansion_index = np.unique(k_reciprocal_expansion_index) weight = np.exp(-original_dist[i, k_reciprocal_expansion_index]) V[i, k_reciprocal_expansion_index] = 1. 
* weight / np.sum(weight) original_dist = original_dist[:query_num, ] if k2 != 1: V_qe = np.zeros_like(V, dtype=np.float32) for i in range(all_num): V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0) V = V_qe del V_qe del initial_rank invIndex = [] for i in range(gallery_num): invIndex.append(np.where(V[:, i] != 0)[0]) jaccard_dist = np.zeros_like(original_dist, dtype=np.float32) for i in range(query_num): temp_min = np.zeros(shape=[1, gallery_num], dtype=np.float32) indNonZero = np.where(V[i, :] != 0)[0] indImages = [invIndex[ind] for ind in indNonZero] for j in range(len(indNonZero)): temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]) jaccard_dist[i] = 1 - temp_min / (2. - temp_min) final_dist = jaccard_dist * (1 - lambda_value) + original_dist * lambda_value del original_dist del V del jaccard_dist final_dist = final_dist[:query_num, query_num:] return final_dist def extract_reid_feature(data_loader, model, device, normalize, print_freq=200): """Extract feature for person ReID. If `normalize` is True, `cosine` distance will be employed as distance metric, otherwise `euclidean` distance. 
def pairwise_distance(feature_dict, query, gallery):
    """Compute the pairwise squared-euclidean distance matrix between query and
    gallery features.

    Args:
        feature_dict (dict): maps image filename to its feature tensor.
        query (list): (filename, person_id, camera_id) tuples of query images.
        gallery (list): (filename, person_id, camera_id) tuples of gallery images.

    Returns:
        tensor of shape ``(len(query), len(gallery))`` on cpu.
    """
    # stack features on cpu: the full m x n matrix may require a large amount of
    # GPU memory; with a large-capacity gpu it would be faster to do this there
    q_feat = torch.cat([feature_dict[name].unsqueeze(0) for name, _, _ in query], dim=0).cpu()
    g_feat = torch.cat([feature_dict[name].unsqueeze(0) for name, _, _ in gallery], dim=0).cpu()
    num_q, num_g = q_feat.size(0), g_feat.size(0)

    # flatten each feature to a vector
    q_feat = q_feat.view(num_q, -1)
    g_feat = g_feat.view(num_g, -1)

    # ||q - g||^2 = ||q||^2 + ||g||^2 - 2 q.g
    q_sq = q_feat.pow(2).sum(dim=1, keepdim=True).expand(num_q, num_g)
    g_sq = g_feat.pow(2).sum(dim=1, keepdim=True).expand(num_g, num_q).t()
    return q_sq + g_sq - 2 * torch.matmul(q_feat, g_feat.t())
first_match_break=True) } cmc_scores = {name: cmc(dist_mat, query_ids, gallery_ids, query_cams, gallery_cams, **params) for name, params in cmc_configs.items()} print('CMC Scores:') for k in cmc_topk: print(' top-{:<4}{:12.1%}'.format(k, cmc_scores['config'][k - 1])) return cmc_scores['config'][0], mAP def validate(val_loader, model, query, gallery, device, criterion='cosine', cmc_flag=False, rerank=False): assert criterion in ['cosine', 'euclidean'] # when criterion == 'cosine', normalize feature of single image into unit norm normalize = (criterion == 'cosine') feature_dict = extract_reid_feature(val_loader, model, device, normalize) dist_mat = pairwise_distance(feature_dict, query, gallery) results = evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag) if not rerank: return results # apply person re-ranking print('Applying person re-ranking') dist_mat_query = pairwise_distance(feature_dict, query, query) dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery) dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery) return evaluate_all(dist_mat, query=query, gallery=gallery, cmc_flag=cmc_flag) # location parameters for visualization GRID_SPACING = 10 QUERY_EXTRA_SPACING = 90 # border width BW = 5 GREEN = (0, 255, 0) RED = (0, 0, 255) def visualize_ranked_results(data_loader, model, query, gallery, device, visualize_dir, criterion='cosine', rerank=False, width=128, height=256, topk=10): """Visualize ranker results. We first compute pair-wise distance between query images and gallery images. Then for every query image, `topk` gallery images with least distance between given query image are selected. We plot the query image and selected gallery images together. A green border denotes a match, and a red one denotes a mis-match. 
""" assert criterion in ['cosine', 'euclidean'] normalize = (criterion == 'cosine') # compute pairwise distance matrix feature_dict = extract_reid_feature(data_loader, model, device, normalize) dist_mat = pairwise_distance(feature_dict, query, gallery) if rerank: dist_mat_query = pairwise_distance(feature_dict, query, query) dist_mat_gallery = pairwise_distance(feature_dict, gallery, gallery) dist_mat = re_ranking(dist_mat, dist_mat_query, dist_mat_gallery) # make dir if not exists os.makedirs(visualize_dir, exist_ok=True) dist_mat = dist_mat.numpy() num_q, num_g = dist_mat.shape print('query images: {}'.format(num_q)) print('gallery images: {}'.format(num_g)) assert num_q == len(query) assert num_g == len(gallery) # start visualizing import cv2 sorted_idxes = np.argsort(dist_mat, axis=1) for q_idx in range(num_q): q_img_path, q_pid, q_cid = query[q_idx] q_img = cv2.imread(q_img_path) q_img = cv2.resize(q_img, (width, height)) # use black border to denote query image q_img = cv2.copyMakeBorder( q_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=(0, 0, 0) ) q_img = cv2.resize(q_img, (width, height)) num_cols = topk + 1 grid_img = 255 * np.ones( (height, num_cols * width + topk * GRID_SPACING + QUERY_EXTRA_SPACING, 3), dtype=np.uint8 ) grid_img[:, :width, :] = q_img # collect top-k gallery images with smallest distance rank_idx = 1 for g_idx in sorted_idxes[q_idx, :]: g_img_path, g_pid, g_cid = gallery[g_idx] invalid = (q_pid == g_pid) & (q_cid == g_cid) if not invalid: matched = (g_pid == q_pid) border_color = GREEN if matched else RED g_img = cv2.imread(g_img_path) g_img = cv2.resize(g_img, (width, height)) g_img = cv2.copyMakeBorder( g_img, BW, BW, BW, BW, cv2.BORDER_CONSTANT, value=border_color ) g_img = cv2.resize(g_img, (width, height)) start = rank_idx * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING end = (rank_idx + 1) * width + rank_idx * GRID_SPACING + QUERY_EXTRA_SPACING grid_img[:, start:end, :] = g_img rank_idx += 1 if rank_idx > topk: break 
save_path = osp.basename(osp.splitext(q_img_path)[0]) cv2.imwrite(osp.join(visualize_dir, save_path + '.jpg'), grid_img) if (q_idx + 1) % 100 == 0: print('Visualize {}/{}'.format(q_idx + 1, num_q)) print('Visualization process is done, ranked results are saved to {}'.format(visualize_dir)) ================================================ FILE: tllib/utils/scheduler.py ================================================ """ Modified from https://github.com/yxgeee/MMT @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch from bisect import bisect_right class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): r"""Starts with a warm-up phase, then decays the learning rate of each parameter group by gamma once the number of epoch reaches one of the milestones. When last_epoch=-1, sets initial lr as lr. Args: optimizer (Optimizer): Wrapped optimizer. milestones (list): List of epoch indices. Must be increasing. gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. warmup_factor (float): a float number :math:`k` between 0 and 1, the start learning rate of warmup phase will be set to :math:`k*initial\_lr` warmup_steps (int): number of warm-up steps. warmup_method (str): "constant" denotes a constant learning rate during warm-up phase and "linear" denotes a linear-increasing learning rate during warm-up phase. last_epoch (int): The index of last epoch. Default: -1. """ def __init__( self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, warmup_steps=500, warmup_method="linear", last_epoch=-1, ): if not list(milestones) == sorted(milestones): raise ValueError( "Milestones should be a list of" " increasing integers. 
Got {}", milestones, ) if warmup_method not in ("constant", "linear"): raise ValueError( "Only 'constant' or 'linear' warmup_method accepted" "got {}".format(warmup_method) ) self.milestones = milestones self.gamma = gamma self.warmup_factor = warmup_factor self.warmup_steps = warmup_steps self.warmup_method = warmup_method super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) def get_lr(self): warmup_factor = 1 if self.last_epoch < self.warmup_steps: if self.warmup_method == "constant": warmup_factor = self.warmup_factor elif self.warmup_method == "linear": alpha = float(self.last_epoch) / float(self.warmup_steps) warmup_factor = self.warmup_factor * (1 - alpha) + alpha return [ base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) for base_lr in self.base_lrs ] ================================================ FILE: tllib/vision/__init__.py ================================================ __all__ = ['datasets', 'models', 'transforms'] ================================================ FILE: tllib/vision/datasets/__init__.py ================================================ from .imagelist import ImageList from .office31 import Office31 from .officehome import OfficeHome from .visda2017 import VisDA2017 from .officecaltech import OfficeCaltech from .domainnet import DomainNet from .imagenet_r import ImageNetR from .imagenet_sketch import ImageNetSketch from .pacs import PACS from .digits import * from .aircrafts import Aircraft from .cub200 import CUB200 from .stanford_cars import StanfordCars from .stanford_dogs import StanfordDogs from .coco70 import COCO70 from .oxfordpets import OxfordIIITPets from .dtd import DTD from .oxfordflowers import OxfordFlowers102 from .patchcamelyon import PatchCamelyon from .retinopathy import Retinopathy from .eurosat import EuroSAT from .resisc45 import Resisc45 from .food101 import Food101 from .sun397 import SUN397 from .caltech101 import Caltech101 from .cifar import CIFAR10, CIFAR100 
__all__ = ['ImageList', 'Office31', 'OfficeHome', "VisDA2017", "OfficeCaltech", "DomainNet", "ImageNetR", "ImageNetSketch", "Aircraft", "cub200", "StanfordCars", "StanfordDogs", "COCO70", "OxfordIIITPets", "PACS", "DTD", "OxfordFlowers102", "PatchCamelyon", "Retinopathy", "EuroSAT", "Resisc45", "Food101", "SUN397", "Caltech101", "CIFAR10", "CIFAR100"] ================================================ FILE: tllib/vision/datasets/_util.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from typing import List from torchvision.datasets.utils import download_and_extract_archive def download(root: str, file_name: str, archive_name: str, url_link: str): """ Download file from internet url link. Args: root (str) The directory to put downloaded files. file_name: (str) The name of the unzipped file. archive_name: (str) The name of archive(zipped file) downloaded. url_link: (str) The url link to download data. .. note:: If `file_name` already exists under path `root`, then it is not downloaded again. Else `archive_name` will be downloaded from `url_link` and extracted to `file_name`. """ if not os.path.exists(os.path.join(root, file_name)): print("Downloading {}".format(file_name)) # if os.path.exists(os.path.join(root, archive_name)): # os.remove(os.path.join(root, archive_name)) try: download_and_extract_archive(url_link, download_root=root, filename=archive_name, remove_finished=False) except Exception: print("Fail to download {} from url link {}".format(archive_name, url_link)) print('Please check you internet connection.' "Simply trying again may be fine.") exit(0) def check_exits(root: str, file_name: str): """Check whether `file_name` exists under directory `root`. 
""" if not os.path.exists(os.path.join(root, file_name)): print("Dataset directory {} not found under {}".format(file_name, root)) exit(-1) def read_list_from_file(file_name: str) -> List[str]: """Read data from file and convert each line into an element in the list""" result = [] with open(file_name, "r") as f: for line in f.readlines(): result.append(line.strip()) return result ================================================ FILE: tllib/vision/datasets/aircrafts.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class Aircraft(ImageList): """`FVGC-Aircraft `_ \ is a benchmark for the fine-grained visual categorization of aircraft. \ The dataset contains 10,200 images of aircraft, with 100 images for each \ of the 102 different aircraft variants. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. sample_rate (int): The sampling rates to sample random ``training`` images for each category. Choices include 100, 50, 30, 15. Default: 100. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. 
:: train/ test/ image_list/ train_100.txt train_50.txt train_30.txt train_15.txt test.txt """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/449157d27987463cbdb1/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/06804f17fdb947aa9401/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/164996d09cc749abbdeb/?dl=1"), ] image_list = { "train": "image_list/train_100.txt", "train100": "image_list/train_100.txt", "train50": "image_list/train_50.txt", "train30": "image_list/train_30.txt", "train15": "image_list/train_15.txt", "test": "image_list/test.txt", "test100": "image_list/test.txt", } CLASSES = ['707-320', '727-200', '737-200', '737-300', '737-400', '737-500', '737-600', '737-700', '737-800', '737-900', '747-100', '747-200', '747-300', '747-400', '757-200', '757-300', '767-200', '767-300', '767-400', '777-200', '777-300', 'A300B4', 'A310', 'A318', 'A319', 'A320', 'A321', 'A330-200', 'A330-300', 'A340-200', 'A340-300', 'A340-500', 'A340-600', 'A380', 'ATR-42', 'ATR-72', 'An-12', 'BAE 146-200', 'BAE 146-300', 'BAE-125', 'Beechcraft 1900', 'Boeing 717', 'C-130', 'C-47', 'CRJ-200', 'CRJ-700', 'CRJ-900', 'Cessna 172', 'Cessna 208', 'Cessna 525', 'Cessna 560', 'Challenger 600', 'DC-10', 'DC-3', 'DC-6', 'DC-8', 'DC-9-30', 'DH-82', 'DHC-1', 'DHC-6', 'DHC-8-100', 'DHC-8-300', 'DR-400', 'Dornier 328', 'E-170', 'E-190', 'E-195', 'EMB-120', 'ERJ 135', 'ERJ 145', 'Embraer Legacy 600', 'Eurofighter Typhoon', 'F-16A-B', 'F-A-18', 'Falcon 2000', 'Falcon 900', 'Fokker 100', 'Fokker 50', 'Fokker 70', 'Global Express', 'Gulfstream IV', 'Gulfstream V', 'Hawk T1', 'Il-76', 'L-1011', 'MD-11', 'MD-80', 'MD-87', 'MD-90', 'Metroliner', 'Model B200', 'PA-28', 'SR-20', 'Saab 2000', 'Saab 340', 'Spitfire', 'Tornado', 'Tu-134', 'Tu-154', 'Yak-42'] def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False, **kwargs): if split == 'train': list_name = 'train' + str(sample_rate) assert 
list_name in self.image_list data_list_file = os.path.join(root, self.image_list[list_name]) else: data_list_file = os.path.join(root, self.image_list['test']) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(Aircraft, self).__init__(root, Aircraft.CLASSES, data_list_file=data_list_file, **kwargs) ================================================ FILE: tllib/vision/datasets/caltech101.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import os from .imagelist import ImageList from ._util import download as download_data, check_exits class Caltech101(ImageList): """`The Caltech101 Dataset `_ contains objects belonging to 101 categories with about 40 to 800 images per category. Most categories have about 50 images. The size of each image is roughly 300 x 200 pixels. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/d6d4b813a800403f835e/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/ed4d0de80da246f98171/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/db1c444200a848799683/?dl=1") ] def __init__(self, root, split='train', download=True, **kwargs): classes = ['accordion', 'airplanes', 'anchor', 'ant', 'background_google', 'barrel', 'bass', 'beaver', 'binocular', 'bonsai', 'brain', 'brontosaurus', 'buddha', 'butterfly', 'camera', 'cannon', 'car_side', 'ceiling_fan', 'cellphone', 'chair', 'chandelier', 'cougar_body', 'cougar_face', 'crab', 'crayfish', 'crocodile', 'crocodile_head', 'cup', 'dalmatian', 'dollar_bill', 'dolphin', 'dragonfly', 'electric_guitar', 'elephant', 'emu', 'euphonium', 'ewer', 'faces', 'faces_easy', 'ferry', 'flamingo', 'flamingo_head', 'garfield', 'gerenuk', 'gramophone', 'grand_piano', 'hawksbill', 'headphone', 'hedgehog', 'helicopter', 'ibis', 'inline_skate', 'joshua_tree', 'kangaroo', 'ketch', 'lamp', 'laptop', 'leopards', 'llama', 'lobster', 'lotus', 'mandolin', 'mayfly', 'menorah', 'metronome', 'minaret', 'motorbikes', 'nautilus', 'octopus', 'okapi', 'pagoda', 'panda', 'pigeon', 'pizza', 'platypus', 'pyramid', 'revolver', 'rhino', 'rooster', 'saxophone', 'schooner', 'scissors', 'scorpion', 'sea_horse', 'snoopy', 'soccer_ball', 'stapler', 'starfish', 'stegosaurus', 'stop_sign', 'strawberry', 'sunflower', 'tick', 'trilobite', 'umbrella', 'watch', 'water_lilly', 'wheelchair', 'wild_cat', 'windsor_chair', 'wrench', 'yin_yang'] if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(Caltech101, self).__init__(root, classes, os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs) ================================================ FILE: tllib/vision/datasets/cifar.py ================================================ 
""" @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from torchvision.datasets.cifar import CIFAR10 as CIFAR10Base, CIFAR100 as CIFAR100Base class CIFAR10(CIFAR10Base): """ `CIFAR10 `_ Dataset. """ def __init__(self, root, split='train', transform=None, download=True): super(CIFAR10, self).__init__(root, train=split == 'train', transform=transform, download=download) self.num_classes = 10 class CIFAR100(CIFAR100Base): """ `CIFAR100 `_ Dataset. """ def __init__(self, root, split='train', transform=None, download=True): super(CIFAR100, self).__init__(root, train=split == 'train', transform=transform, download=download) self.num_classes = 100 ================================================ FILE: tllib/vision/datasets/coco70.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class COCO70(ImageList): """COCO-70 dataset is a large-scale classification dataset (1000 images per class) created from `COCO `_ Dataset. It is used to explore the effect of fine-tuning with a large amount of data. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. sample_rate (int): The sampling rates to sample random ``training`` images for each category. Choices include 100, 50, 30, 15. Default: 100. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. 
:: train/ test/ image_list/ train_100.txt train_50.txt train_30.txt train_15.txt test.txt """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/b008c0d823ad488c8be1/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/75a895576d5e4e59a88d/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/ec6e45bc830d42f0924a/?dl=1"), ] image_list = { "train": "image_list/train_100.txt", "train100": "image_list/train_100.txt", "train50": "image_list/train_50.txt", "train30": "image_list/train_30.txt", "train15": "image_list/train_15.txt", "test": "image_list/test.txt", "test100": "image_list/test.txt", } CLASSES =['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'skis', 'kite', 'baseball_bat', 'skateboard', 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop', 'remote', 'keyboard', 'cell_phone', 'microwave', 'oven', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'teddy_bear'] def __init__(self, root: str, split: str, sample_rate: Optional[int] =100, download: Optional[bool] = False, **kwargs): if split == 'train': list_name = 'train' + str(sample_rate) assert list_name in self.image_list data_list_file = os.path.join(root, self.image_list[list_name]) else: data_list_file = os.path.join(root, self.image_list['test']) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(COCO70, self).__init__(root, COCO70.CLASSES, data_list_file=data_list_file, **kwargs) 
================================================ FILE: tllib/vision/datasets/cub200.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class CUB200(ImageList): """`Caltech-UCSD Birds-200-2011 `_ \ is a dataset for fine-grained visual recognition with 11,788 images in 200 bird species. \ It is an extended version of the CUB-200 dataset, roughly doubling the number of images. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. sample_rate (int): The sampling rates to sample random ``training`` images for each category. Choices include 100, 50, 30, 15. Default: 100. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. 
:: train/ test/ image_list/ train_100.txt train_50.txt train_30.txt train_15.txt test.txt """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c2a5952eb18b466b9fb0/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/63db4c49b57b43198b95/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/72e95cccdcaf4b42b4eb/?dl=1"), ] image_list = { "train": "image_list/train_100.txt", "train100": "image_list/train_100.txt", "train50": "image_list/train_50.txt", "train30": "image_list/train_30.txt", "train15": "image_list/train_15.txt", "test": "image_list/test.txt", "test100": "image_list/test.txt", } CLASSES = ['001.Black_footed_Albatross', '002.Laysan_Albatross', '003.Sooty_Albatross', '004.Groove_billed_Ani', '005.Crested_Auklet', '006.Least_Auklet', '007.Parakeet_Auklet', '008.Rhinoceros_Auklet', '009.Brewer_Blackbird', '010.Red_winged_Blackbird', '011.Rusty_Blackbird', '012.Yellow_headed_Blackbird', '013.Bobolink', '014.Indigo_Bunting', '015.Lazuli_Bunting', '016.Painted_Bunting', '017.Cardinal', '018.Spotted_Catbird', '019.Gray_Catbird', '020.Yellow_breasted_Chat', '021.Eastern_Towhee', '022.Chuck_will_Widow', '023.Brandt_Cormorant', '024.Red_faced_Cormorant', '025.Pelagic_Cormorant', '026.Bronzed_Cowbird', '027.Shiny_Cowbird', '028.Brown_Creeper', '029.American_Crow', '030.Fish_Crow', '031.Black_billed_Cuckoo', '032.Mangrove_Cuckoo', '033.Yellow_billed_Cuckoo', '034.Gray_crowned_Rosy_Finch', '035.Purple_Finch', '036.Northern_Flicker', '037.Acadian_Flycatcher', '038.Great_Crested_Flycatcher', '039.Least_Flycatcher', '040.Olive_sided_Flycatcher', '041.Scissor_tailed_Flycatcher', '042.Vermilion_Flycatcher', '043.Yellow_bellied_Flycatcher', '044.Frigatebird', '045.Northern_Fulmar', '046.Gadwall', '047.American_Goldfinch', '048.European_Goldfinch', '049.Boat_tailed_Grackle', '050.Eared_Grebe', '051.Horned_Grebe', '052.Pied_billed_Grebe', '053.Western_Grebe', '054.Blue_Grosbeak', '055.Evening_Grosbeak', 
'056.Pine_Grosbeak', '057.Rose_breasted_Grosbeak', '058.Pigeon_Guillemot', '059.California_Gull', '060.Glaucous_winged_Gull', '061.Heermann_Gull', '062.Herring_Gull', '063.Ivory_Gull', '064.Ring_billed_Gull', '065.Slaty_backed_Gull', '066.Western_Gull', '067.Anna_Hummingbird', '068.Ruby_throated_Hummingbird', '069.Rufous_Hummingbird', '070.Green_Violetear', '071.Long_tailed_Jaeger', '072.Pomarine_Jaeger', '073.Blue_Jay', '074.Florida_Jay', '075.Green_Jay', '076.Dark_eyed_Junco', '077.Tropical_Kingbird', '078.Gray_Kingbird', '079.Belted_Kingfisher', '080.Green_Kingfisher', '081.Pied_Kingfisher', '082.Ringed_Kingfisher', '083.White_breasted_Kingfisher', '084.Red_legged_Kittiwake', '085.Horned_Lark', '086.Pacific_Loon', '087.Mallard', '088.Western_Meadowlark', '089.Hooded_Merganser', '090.Red_breasted_Merganser', '091.Mockingbird', '092.Nighthawk', '093.Clark_Nutcracker', '094.White_breasted_Nuthatch', '095.Baltimore_Oriole', '096.Hooded_Oriole', '097.Orchard_Oriole', '098.Scott_Oriole', '099.Ovenbird', '100.Brown_Pelican', '101.White_Pelican', '102.Western_Wood_Pewee', '103.Sayornis', '104.American_Pipit', '105.Whip_poor_Will', '106.Horned_Puffin', '107.Common_Raven', '108.White_necked_Raven', '109.American_Redstart', '110.Geococcyx', '111.Loggerhead_Shrike', '112.Great_Grey_Shrike', '113.Baird_Sparrow', '114.Black_throated_Sparrow', '115.Brewer_Sparrow', '116.Chipping_Sparrow', '117.Clay_colored_Sparrow', '118.House_Sparrow', '119.Field_Sparrow', '120.Fox_Sparrow', '121.Grasshopper_Sparrow', '122.Harris_Sparrow', '123.Henslow_Sparrow', '124.Le_Conte_Sparrow', '125.Lincoln_Sparrow', '126.Nelson_Sharp_tailed_Sparrow', '127.Savannah_Sparrow', '128.Seaside_Sparrow', '129.Song_Sparrow', '130.Tree_Sparrow', '131.Vesper_Sparrow', '132.White_crowned_Sparrow', '133.White_throated_Sparrow', '134.Cape_Glossy_Starling', '135.Bank_Swallow', '136.Barn_Swallow', '137.Cliff_Swallow', '138.Tree_Swallow', '139.Scarlet_Tanager', '140.Summer_Tanager', '141.Artic_Tern', 
'142.Black_Tern', '143.Caspian_Tern', '144.Common_Tern', '145.Elegant_Tern', '146.Forsters_Tern', '147.Least_Tern', '148.Green_tailed_Towhee', '149.Brown_Thrasher', '150.Sage_Thrasher', '151.Black_capped_Vireo', '152.Blue_headed_Vireo', '153.Philadelphia_Vireo', '154.Red_eyed_Vireo', '155.Warbling_Vireo', '156.White_eyed_Vireo', '157.Yellow_throated_Vireo', '158.Bay_breasted_Warbler', '159.Black_and_white_Warbler', '160.Black_throated_Blue_Warbler', '161.Blue_winged_Warbler', '162.Canada_Warbler', '163.Cape_May_Warbler', '164.Cerulean_Warbler', '165.Chestnut_sided_Warbler', '166.Golden_winged_Warbler', '167.Hooded_Warbler', '168.Kentucky_Warbler', '169.Magnolia_Warbler', '170.Mourning_Warbler', '171.Myrtle_Warbler', '172.Nashville_Warbler', '173.Orange_crowned_Warbler', '174.Palm_Warbler', '175.Pine_Warbler', '176.Prairie_Warbler', '177.Prothonotary_Warbler', '178.Swainson_Warbler', '179.Tennessee_Warbler', '180.Wilson_Warbler', '181.Worm_eating_Warbler', '182.Yellow_Warbler', '183.Northern_Waterthrush', '184.Louisiana_Waterthrush', '185.Bohemian_Waxwing', '186.Cedar_Waxwing', '187.American_Three_toed_Woodpecker', '188.Pileated_Woodpecker', '189.Red_bellied_Woodpecker', '190.Red_cockaded_Woodpecker', '191.Red_headed_Woodpecker', '192.Downy_Woodpecker', '193.Bewick_Wren', '194.Cactus_Wren', '195.Carolina_Wren', '196.House_Wren', '197.Marsh_Wren', '198.Rock_Wren', '199.Winter_Wren', '200.Common_Yellowthroat'] def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100, download: Optional[bool] = False, **kwargs): if split == 'train': list_name = 'train' + str(sample_rate) assert list_name in self.image_list data_list_file = os.path.join(root, self.image_list[list_name]) else: data_list_file = os.path.join(root, self.image_list['test']) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(CUB200, self).__init__(root, 
CUB200.CLASSES, data_list_file=data_list_file, **kwargs) ================================================ FILE: tllib/vision/datasets/digits.py ================================================ """ @author: Junguang Jiang, Baixu Chen @contact: JiangJunguang1123@outlook.com, cbx_99_hasta@outlook.com """ import os from typing import Optional, Tuple, Any from .imagelist import ImageList from ._util import download as download_data, check_exits class MNIST(ImageList): """`MNIST `_ Dataset. Args: root (str): Root directory of dataset where ``MNIST/processed/training.pt`` and ``MNIST/processed/test.pt`` exist. mode (str): The channel mode for image. Choices includes ``"L"```, ``"RGB"``. Default: ``"L"``` split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. 
E.g, ``transforms.RandomCrop`` """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/16feadf7fb3641c2be9a/?dl=1"), ("mnist_train_image", "mnist_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1"), # ("mnist_test_image", "mnist_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c93080af28e54559aeeb/?dl=1") ] image_list = { "train": "image_list/mnist_train.txt", "test": "image_list/mnist_test.txt" } CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] def __init__(self, root, mode="L", split='train', download: Optional[bool] = True, **kwargs): assert split in ['train', 'test'] data_list_file = os.path.join(root, self.image_list[split]) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) assert mode in ['L', 'RGB'] self.mode = mode super(MNIST, self).__init__(root, MNIST.CLASSES, data_list_file=data_list_file, **kwargs) def __getitem__(self, index: int) -> Tuple[Any, int]: """ Args: index (int): Index return (tuple): (image, target) where target is index of the target class. """ path, target = self.samples[index] img = self.loader(path).convert(self.mode) if self.transform is not None: img = self.transform(img) if self.target_transform is not None and target is not None: target = self.target_transform(target) return img, target @classmethod def get_classes(self): return MNIST.CLASSES class USPS(ImageList): """`USPS `_ Dataset. The data-format is : [label [index:value ]*256 \\n] * num_lines, where ``label`` lies in ``[1, 10]``. The value for each pixel lies in ``[-1, 1]``. Here we transform the ``label`` into ``[0, 9]`` and make pixel values in ``[0, 255]``. Args: root (str): Root directory of dataset to store``USPS`` data files. mode (str): The channel mode for image. Choices includes ``"L"```, ``"RGB"``. 
Default: ``"L"``` split (str, optional): The dataset split, supports ``train``, or ``test``. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/721ceaf3c031413cb62f/?dl=1"), ("usps_train_image", "usps_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1"), # ("usps_test_image", "usps_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/c5bd329a00fb4dc79608/?dl=1") ] image_list = { "train": "image_list/usps_train.txt", "test": "image_list/usps_test.txt" } CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] def __init__(self, root, mode="L", split='train', download: Optional[bool] = True, **kwargs): assert split in ['train', 'test'] data_list_file = os.path.join(root, self.image_list[split]) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) assert mode in ['L', 'RGB'] self.mode = mode super(USPS, self).__init__(root, USPS.CLASSES, data_list_file=data_list_file, **kwargs) def __getitem__(self, index: int) -> Tuple[Any, int]: """ Args: index (int): Index return (tuple): (image, target) where target is index of the target class. """ path, target = self.samples[index] img = self.loader(path).convert(self.mode) if self.transform is not None: img = self.transform(img) if self.target_transform is not None and target is not None: target = self.target_transform(target) return img, target class SVHN(ImageList): """`SVHN `_ Dataset. Note: The SVHN dataset assigns the label `10` to the digit `0`. 
However, in this Dataset, we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which expect the class labels to be in the range `[0, C-1]` .. warning:: This class needs `scipy `_ to load data from `.mat` format. Args: root (str): Root directory of dataset where directory ``SVHN`` exists. mode (str): The channel mode for image. Choices includes ``"L"```, ``"RGB"``. Default: ``"RGB"``` split (str, optional): The dataset split, supports ``train``, or ``test``. transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed version. E.g, ``transforms.RandomCrop`` download (bool, optional): If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/12b35fb08f8049f98362/?dl=1"), ("svhn_image", "svhn_image.tar.gz", "https://cloud.tsinghua.edu.cn/f/cc02de6cf81543378cce/?dl=1") ] image_list = "image_list/svhn_balanced.txt" # image_list = "image_list/svhn.txt" CLASSES = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] def __init__(self, root, mode="L", download: Optional[bool] = True, **kwargs): data_list_file = os.path.join(root, self.image_list) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) assert mode in ['L', 'RGB'] self.mode = mode super(SVHN, self).__init__(root, SVHN.CLASSES, data_list_file=data_list_file, **kwargs) def __getitem__(self, index: int) -> Tuple[Any, int]: """ Args: index (int): Index return (tuple): (image, target) where target is index of the target class. 
""" path, target = self.samples[index] img = self.loader(path).convert(self.mode) if self.transform is not None: img = self.transform(img) if self.target_transform is not None and target is not None: target = self.target_transform(target) return img, target class MNISTRGB(MNIST): def __init__(self, root, **kwargs): super(MNISTRGB, self).__init__(root, mode='RGB', **kwargs) class USPSRGB(USPS): def __init__(self, root, **kwargs): super(USPSRGB, self).__init__(root, mode='RGB', **kwargs) class SVHNRGB(SVHN): def __init__(self, root, **kwargs): super(SVHNRGB, self).__init__(root, mode='RGB', **kwargs) ================================================ FILE: tllib/vision/datasets/domainnet.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class DomainNet(ImageList): """`DomainNet `_ (cleaned version, recommended) See `Moment Matching for Multi-Source Domain Adaptation `_ for details. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'c'``:clipart, \ ``'i'``: infograph, ``'p'``: painting, ``'q'``: quickdraw, ``'r'``: real, ``'s'``: sketch split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. :: clipart/ infograph/ painting/ quickdraw/ real/ sketch/ image_list/ clipart.txt ... 
""" download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/bf0fe327e4b046eb89ba/?dl=1"), ("clipart", "clipart.tgz", "https://cloud.tsinghua.edu.cn/f/f0515164a4864220b98b/?dl=1"), ("infograph", "infograph.tgz", "https://cloud.tsinghua.edu.cn/f/98b19d5fc9884109a9cb/?dl=1"), ("painting", "painting.tgz", "https://cloud.tsinghua.edu.cn/f/11285ce9fbd34bb7b28c/?dl=1"), ("quickdraw", "quickdraw.tgz", "https://cloud.tsinghua.edu.cn/f/6faa9efb498b494abf66/?dl=1"), ("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/17a101842c564959b525/?dl=1"), ("sketch", "sketch.tgz", "https://cloud.tsinghua.edu.cn/f/b305add26e9d47349495/?dl=1"), ] image_list = { "c": "clipart", "i": "infograph", "p": "painting", "q": "quickdraw", "r": "real", "s": "sketch", } CLASSES = ['aircraft_carrier', 'airplane', 'alarm_clock', 'ambulance', 'angel', 'animal_migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball_bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday_cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling_fan', 'cello', 'cell_phone', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee_cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise_ship', 'cup', 'diamond', 'dishwasher', 'diving_board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire_hydrant', 
'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip_flops', 'floor_lamp', 'flower', 'flying_saucer', 'foot', 'fork', 'frog', 'frying_pan', 'garden', 'garden_hose', 'giraffe', 'goatee', 'golf_club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey_puck', 'hockey_stick', 'horse', 'hospital', 'hot_air_balloon', 'hot_dog', 'hot_tub', 'hourglass', 'house', 'house_plant', 'hurricane', 'ice_cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'knife', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light_bulb', 'lighter', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paintbrush', 'paint_can', 'palm_tree', 'panda', 'pants', 'paper_clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup_truck', 'picture_frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police_car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power_outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote_control', 'rhinoceros', 'rifle', 'river', 'roller_coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school_bus', 'scissors', 'scorpion', 'screwdriver', 'sea_turtle', 'see_saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping_bag', 'smiley_face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer_ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop_sign', 'stove', 'strawberry', 'streetlight', 
'string_bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing_set', 'sword', 'syringe', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis_racquet', 'tent', 'The_Eiffel_Tower', 'The_Great_Wall_of_China', 'The_Mona_Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic_light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 't-shirt', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing_machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine_bottle', 'wine_glass', 'wristwatch', 'yoga', 'zebra', 'zigzag'] def __init__(self, root: str, task: str, split: Optional[str] = 'train', download: Optional[float] = False, **kwargs): assert task in self.image_list assert split in ['train', 'test'] data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split)) print("loading {}".format(data_list_file)) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda args: check_exits(root, args[0]), self.download_list)) super(DomainNet, self).__init__(root, DomainNet.CLASSES, data_list_file=data_list_file, **kwargs) @classmethod def domains(cls): return list(cls.image_list.keys()) ================================================ FILE: tllib/vision/datasets/dtd.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .imagelist import ImageList from ._util import download as download_data, check_exits class DTD(ImageList): """ `The Describable Textures Dataset (DTD) `_ is an \ evolving collection of textural images in the wild, annotated with a series of human-centric attributes, \ inspired by the perceptual properties of textures. \ The task consists in classifying images of textural patterns (47 classes, with 120 training images each). 
\ Some of the textures are banded, bubbly, meshed, lined, or porous. \ The image size ranges between 300x300 and 640x640 pixels. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/2218bfa61bac46539dd7/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/08fd47d35fc94f36a508/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/15873fe162c343cca8ed/?dl=1"), ("validation", "validation.tgz", "https://cloud.tsinghua.edu.cn/f/75c9ab22ebea4c3b87e7/?dl=1"), ] CLASSES = ['banded', 'blotchy', 'braided', 'bubbly', 'bumpy', 'chequered', 'cobwebbed', 'cracked', 'crosshatched', 'crystalline', 'dotted', 'fibrous', 'flecked', 'freckled', 'frilly', 'gauzy', 'grid', 'grooved', 'honeycombed', 'interlaced', 'knitted', 'lacelike', 'lined', 'marbled', 'matted', 'meshed', 'paisley', 'perforated', 'pitted', 'pleated', 'polka-dotted', 'porous', 'potholed', 'scaly', 'smeared', 'spiralled', 'sprinkled', 'stained', 'stratified', 'striped', 'studded', 'swirly', 'veined', 'waffled', 'woven', 'wrinkled', 'zigzagged'] def __init__(self, root, split, download=False, **kwargs): if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(DTD, self).__init__(root, DTD.CLASSES, os.path.join(root, "image_list", "{}.txt".format(split)), **kwargs) 
================================================ FILE: tllib/vision/datasets/eurosat.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .imagelist import ImageList from ._util import download as download_data, check_exits class EuroSAT(ImageList): """ `EuroSAT `_ dataset consists in classifying \ Sentinel-2 satellite images into 10 different types of land use (Residential, \ Industrial, River, Highway, etc). \ The spatial resolution corresponds to 10 meters per pixel, and the image size \ is 64x64 pixels. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" CLASSES =['AnnualCrop', 'Forest', 'HerbaceousVegetation', 'Highway', 'Industrial', 'Pasture', 'PermanentCrop', 'Residential', 'River', 'SeaLake'] def __init__(self, root, split='train', download=False, **kwargs): if download: download_data(root, "eurosat", "eurosat.tgz", "https://cloud.tsinghua.edu.cn/f/9983d7ab86184d74bb17/?dl=1") else: check_exits(root, "eurosat") split = 'train[:21600]' if split == 'train' else 'train[21600:]' root = os.path.join(root, "eurosat") super(EuroSAT, self).__init__(root, EuroSAT.CLASSES, os.path.join(root, "imagelist", "{}.txt".format(split)), **kwargs) ================================================ FILE: tllib/vision/datasets/food101.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from torchvision.datasets.folder import ImageFolder import os.path as osp from ._util import download as download_data, check_exits class Food101(ImageFolder): """`Food-101 `_ is a dataset for fine-grained visual recognition with 101,000 images in 101 food categories. Args: root (str): Root directory of dataset. split (str, optional): The dataset split, supports ``train``, or ``test``. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. .. note:: In `root`, there will exist following files after downloading. 
:: train/ test/ """ download_list = [ ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/1d7bd727cc1e4ce2bef5/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/7e11992d7495417db32b/?dl=1") ] def __init__(self, root, split='train', transform=None, download=True): if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(Food101, self).__init__(osp.join(root, split), transform=transform) self.num_classes = 101 ================================================ FILE: tllib/vision/datasets/imagelist.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os import warnings from typing import Optional, Callable, Tuple, Any, List, Iterable import bisect from torch.utils.data.dataset import Dataset, T_co, IterableDataset import torchvision.datasets as datasets from torchvision.datasets.folder import default_loader class ImageList(datasets.VisionDataset): """A generic Dataset class for image classification Args: root (str): Root directory of dataset classes (list[str]): The names of all the classes data_list_file (str): File to read the image list from. transform (callable, optional): A function/transform that takes in an PIL image \ and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `data_list_file`, each line has 2 values in the following format. :: source_dir/dog_xxx.png 0 source_dir/cat_123.png 1 target_dir/dog_xxy.png 0 target_dir/cat_nsdf3.png 1 The first value is the relative path of an image, and the second value is the label of the corresponding image. If your data_list_file has different formats, please over-ride :meth:`~ImageList.parse_data_file`. 
""" def __init__(self, root: str, classes: List[str], data_list_file: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): super().__init__(root, transform=transform, target_transform=target_transform) self.samples = self.parse_data_file(data_list_file) self.targets = [s[1] for s in self.samples] self.classes = classes self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)} self.loader = default_loader self.data_list_file = data_list_file def __getitem__(self, index: int) -> Tuple[Any, int]: """ Args: index (int): Index return (tuple): (image, target) where target is index of the target class. """ path, target = self.samples[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.target_transform is not None and target is not None: target = self.target_transform(target) return img, target def __len__(self) -> int: return len(self.samples) def parse_data_file(self, file_name: str) -> List[Tuple[str, int]]: """Parse file to data list Args: file_name (str): The path of data file return (list): List of (image path, class_index) tuples """ with open(file_name, "r") as f: data_list = [] for line in f.readlines(): split_line = line.split() target = split_line[-1] path = ' '.join(split_line[:-1]) if not os.path.isabs(path): path = os.path.join(self.root, path) target = int(target) data_list.append((path, target)) return data_list @property def num_classes(self) -> int: """Number of classes""" return len(self.classes) @classmethod def domains(cls): """All possible domain in this dataset""" raise NotImplemented class MultipleDomainsDataset(Dataset[T_co]): r"""Dataset as a concatenation of multiple datasets. This class is useful to assemble different existing datasets. 
Args: datasets (sequence): List of datasets to be concatenated """ datasets: List[Dataset[T_co]] cumulative_sizes: List[int] @staticmethod def cumsum(sequence): r, s = [], 0 for e in sequence: l = len(e) r.append(l + s) s += l return r def __init__(self, domains: Iterable[Dataset], domain_names: Iterable[str], domain_ids) -> None: super(MultipleDomainsDataset, self).__init__() # Cannot verify that datasets is Sized assert len(domains) > 0, 'datasets should not be an empty iterable' # type: ignore[arg-type] self.datasets = self.domains = list(domains) for d in self.domains: assert not isinstance(d, IterableDataset), "MultipleDomainsDataset does not support IterableDataset" self.cumulative_sizes = self.cumsum(self.domains) self.domain_names = domain_names self.domain_ids = domain_ids def __len__(self): return self.cumulative_sizes[-1] def __getitem__(self, idx): if idx < 0: if -idx > len(self): raise ValueError("absolute value of index should not exceed dataset length") idx = len(self) + idx dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) if dataset_idx == 0: sample_idx = idx else: sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] return self.domains[dataset_idx][sample_idx] + (self.domain_ids[dataset_idx],) @property def cummulative_sizes(self): warnings.warn("cummulative_sizes attribute is renamed to " "cumulative_sizes", DeprecationWarning, stacklevel=2) return self.cumulative_sizes ================================================ FILE: tllib/vision/datasets/imagenet_r.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional import os from .imagelist import ImageList from ._util import download as download_data, check_exits class ImageNetR(ImageList): """ImageNet-R Dataset. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \ ``'D'``: dslr and ``'W'``: webcam. 
download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: You need to put ``train`` directory of ImageNet-1K and ``imagenet_r`` directory of ImageNet-R manually in `root` directory. DALIB will only download ImageList automatically. In `root`, there will exist following files after preparing. :: train/ n02128385/ ... val/ imagenet-r/ n02128385/ image_list/ imagenet-train.txt imagenet-r.txt art.txt ... """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1"), ] image_list = { "IN": "image_list/imagenet-train.txt", "IN-val": "image_list/imagenet-val.txt", "INR": "image_list/imagenet-r.txt", "art": "art.txt", "embroidery": "embroidery.txt", "misc": "misc.txt", "sculpture": "sculpture.txt", "tattoo": "tattoo.txt", "cartoon": "cartoon.txt", "graffiti": "graffiti.txt", "origami": "origami.txt", "sketch": "sketch.txt", "toy": "toy.txt", "deviantart": "deviantart.txt", "graphic": "graphic.txt", "painting": "painting.txt", "sticker": "sticker.txt", "videogame": "videogame.txt" } CLASSES = ['n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214' , 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 
'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 
'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'] def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs): assert task in self.image_list assert split in ["train", "val", "all"] if task == "IN" and split == "val": task = "IN-val" data_list_file = os.path.join(root, self.image_list[task]) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(ImageNetR, self).__init__(root, ImageNetR.CLASSES, data_list_file=data_list_file, **kwargs) @classmethod def domains(cls): return list(cls.image_list.keys()) ================================================ FILE: tllib/vision/datasets/imagenet_sketch.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional import os from torchvision.datasets.imagenet import ImageNet from .imagelist import ImageList from ._util import download as download_data, check_exits class ImageNetSketch(ImageList): """ImageNet-Sketch Dataset. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon, \ ``'D'``: dslr and ``'W'``: webcam. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: You need to put ``train`` directory, ``metabin`` of ImageNet-1K and ``sketch`` directory of ImageNet-Sketch manually in `root` directory. DALIB will only download ImageList automatically. 
In `root`, there will exist following files after preparing. :: metabin (from ImageNet) train/ n02128385/ ... val/ sketch/ n02128385/ image_list/ imagenet-train.txt sketch.txt ... """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7786eabd3565409c8c33/?dl=1"), ] image_list = { "IN": "image_list/imagenet-train.txt", "IN-val": "image_list/imagenet-val.txt", "sketch": "image_list/sketch.txt", } def __init__(self, root: str, task: str, split: Optional[str] = 'all', download: Optional[bool] = True, **kwargs): assert task in self.image_list assert split in ["train", "val", "all"] if task == "IN" and split == "val": task = "IN-val" data_list_file = os.path.join(root, self.image_list[task]) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(ImageNetSketch, self).__init__(root, ImageNet(root).classes, data_list_file=data_list_file, **kwargs) @classmethod def domains(cls): return list(cls.image_list.keys()) ================================================ FILE: tllib/vision/datasets/keypoint_detection/__init__.py ================================================ from .rendered_hand_pose import RenderedHandPose from .hand_3d_studio import Hand3DStudio, Hand3DStudioAll from .freihand import FreiHand from .surreal import SURREAL from .lsp import LSP from .human36m import Human36M ================================================ FILE: tllib/vision/datasets/keypoint_detection/freihand.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import json import time import torch import os import os.path as osp from torchvision.datasets.utils import download_and_extract_archive from ...transforms.keypoint_detection import * from .keypoint_dataset import Hand21KeypointDataset from .util import * """ General util functions. 
""" def _assert_exist(p): msg = 'File does not exists: %s' % p assert os.path.exists(p), msg def json_load(p): _assert_exist(p) with open(p, 'r') as fi: d = json.load(fi) return d def load_db_annotation(base_path, set_name=None): if set_name is None: # only training set annotations are released so this is a valid default choice set_name = 'training' print('Loading FreiHAND dataset index ...') t = time.time() # assumed paths to data containers k_path = os.path.join(base_path, '%s_K.json' % set_name) mano_path = os.path.join(base_path, '%s_mano.json' % set_name) xyz_path = os.path.join(base_path, '%s_xyz.json' % set_name) # load if exist K_list = json_load(k_path) mano_list = json_load(mano_path) xyz_list = json_load(xyz_path) # should have all the same length assert len(K_list) == len(mano_list), 'Size mismatch.' assert len(K_list) == len(xyz_list), 'Size mismatch.' print('Loading of %d samples done in %.2f seconds' % (len(K_list), time.time()-t)) return list(zip(K_list, mano_list, xyz_list)) def projectPoints(xyz, K): """ Project 3D coordinates into image space. """ xyz = np.array(xyz) K = np.array(K) uv = np.matmul(K, xyz.T).T return uv[:, :2] / uv[:, -1:] """ Dataset related functions. """ def db_size(set_name): """ Hardcoded size of the datasets. """ if set_name == 'training': return 32560 # number of unique samples (they exists in multiple 'versions') elif set_name == 'evaluation': return 3960 else: assert 0, 'Invalid choice.' 
class sample_version: gs = 'gs' # green screen hom = 'hom' # homogenized sample = 'sample' # auto colorization with sample points auto = 'auto' # auto colorization without sample points: automatic color hallucination db_size = db_size('training') @classmethod def valid_options(cls): return [cls.gs, cls.hom, cls.sample, cls.auto] @classmethod def check_valid(cls, version): msg = 'Invalid choice: "%s" (must be in %s)' % (version, cls.valid_options()) assert version in cls.valid_options(), msg @classmethod def map_id(cls, id, version): cls.check_valid(version) return id + cls.db_size*cls.valid_options().index(version) class FreiHand(Hand21KeypointDataset): """`FreiHand Dataset `_ Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``. task (str, optional): The post-processing option to create dataset. Choices include ``'gs'``: green screen \ recording, ``'auto'``: auto colorization without sample points: automatic color hallucination, \ ``'sample'``: auto colorization with sample points, ``'hom'``: homogenized, \ and ``'all'``: all hands. Default: 'all'. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`. image_size (tuple): (width, height) of the image. Default: (256, 256) heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64) sigma (int): sigma parameter when generate the heatmap. Default: 2 .. note:: In `root`, there will exist following files after downloading. 
        ::
            *.json
            training/
            evaluation/
    """

    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        if download:
            # download only when neither extracted directory is present yet
            if not osp.exists(osp.join(root, "training")) or not osp.exists(osp.join(root, "evaluation")):
                download_and_extract_archive("https://lmb.informatik.uni-freiburg.de/data/freihand/FreiHAND_pub_v2.zip",
                                             download_root=root, filename="FreiHAND_pub_v2.zip",
                                             remove_finished=False, extract_root=root)

        assert split in ['train', 'test', 'all']
        self.split = split

        assert task in ['all', 'gs', 'auto', 'sample', 'hom']
        self.task = task
        if task == 'all':
            samples = self.get_samples(root, 'gs') + self.get_samples(root, 'auto') + self.get_samples(root, 'sample') + self.get_samples(root, 'hom')
        else:
            samples = self.get_samples(root, task)
        # fixed seed so the train/test partition is reproducible across runs
        random.seed(42)
        random.shuffle(samples)
        samples_len = len(samples)
        # test split is 20% of the data, capped at 3200 samples
        samples_split = min(int(samples_len * 0.2), 3200)
        if self.split == 'train':
            samples = samples[samples_split:]
        elif self.split == 'test':
            samples = samples[:samples_split]

        super(FreiHand, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]

        # Crop the images such that the hand is at the center of the image
        # The images will be 1.5 times larger than the hand
        # The crop process will change Xc and Yc, leaving Zc with no changes
        bounding_box = get_bounding_box(keypoint2d)
        w, h = image.size
        left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
        image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)

        # Change all hands to right hands
        if sample['left'] is False:
            image, keypoint2d = hflip(image, keypoint2d)

        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # re-project the (possibly transformed) 2D keypoints back to camera space
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)

        # normalize 2D pose:
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
        z = keypoint3d_n[:, 2]

        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
            'z': z,
        }

        return image, target, target_weight, meta

    def get_samples(self, root, version='gs'):
        # Build the raw sample dicts for one post-processing version.
        set = 'training'
        # load annotations of this set
        db_data_anno = load_db_annotation(root, set)
        version_map = {
            'gs': sample_version.gs,
            'hom': sample_version.hom,
            'sample': sample_version.sample,
            'auto': sample_version.auto
        }
        samples = []
        for idx in range(db_size(set)):
            # image filename encodes the version via the global id offset
            image_name = os.path.join(set, 'rgb', '%08d.jpg' % sample_version.map_id(idx, version_map[version]))
            mask_name = os.path.join(set, 'mask', '%08d.jpg' % idx)
            intrinsic_matrix, mano, keypoint3d = db_data_anno[idx]
            keypoint2d = projectPoints(keypoint3d, intrinsic_matrix)
            sample = {
                'name': image_name,
                'mask_name': mask_name,
                'keypoint2d': keypoint2d,
                'keypoint3d': keypoint3d,
                'intrinsic_matrix': intrinsic_matrix,
                'left': False  # all FreiHAND recordings are right hands
            }
            samples.append(sample)

        return samples



================================================
FILE: tllib/vision/datasets/keypoint_detection/hand_3d_studio.py
================================================
"""
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
import json
import random
from PIL import ImageFile, Image
import torch
import os.path as osp

from .._util import download as download_data, check_exits
from .keypoint_dataset import Hand21KeypointDataset
from .util import *

# Tolerate truncated image files instead of raising at load time.
ImageFile.LOAD_TRUNCATED_IMAGES = True


class Hand3DStudio(Hand21KeypointDataset):
    """`Hand-3d-Studio Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
        task (str, optional): The task to create dataset. Choices include ``'noobject'``: only hands without objects, \
            ``'object'``: only hands interacting with hands, and ``'all'``: all hands. Default: 'noobject'.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: We found that the original H3D image is in high resolution while most part in an image is background,
        thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training.

    .. note:: In `root`, there will exist following files after downloading.
        ::

            H3D_crop/
                annotation.json
                part1/
                part2/
                part3/
                part4/
                part5/
    """
    def __init__(self, root, split='train', task='noobject', download=True, **kwargs):
        assert split in ['train', 'test', 'all']
        self.split = split
        assert task in ['noobject', 'object', 'all']
        self.task = task
        if download:
            download_data(root, "H3D_crop", "H3D_crop.tar", "https://cloud.tsinghua.edu.cn/f/d4e612e44dc04d8eb01f/?dl=1")
        else:
            check_exits(root, "H3D_crop")
        root = osp.join(root, "H3D_crop")
        # load labels
        annotation_file = os.path.join(root, 'annotation.json')
        print("loading from {}".format(annotation_file))
        with open(annotation_file) as f:
            samples = list(json.load(f))
        # Filter by hand/object interaction; 'without_object' is stored as a
        # 0/1 flag in the annotation file.
        if task == 'noobject':
            samples = [sample for sample in samples if int(sample['without_object']) == 1]
        elif task == 'object':
            samples = [sample for sample in samples if int(sample['without_object']) == 0]

        # Deterministic shuffle, then hold out 20% (capped at 3200) for test.
        random.seed(42)
        random.shuffle(samples)
        samples_len = len(samples)
        samples_split = min(int(samples_len * 0.2), 3200)
        if split == 'train':
            samples = samples[samples_split:]
        elif split == 'test':
            samples = samples[:samples_split]

        super(Hand3DStudio, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)``; ``target`` is the
        2D-heatmap tensor and ``meta`` carries transformed 2D and normalized
        3D keypoints."""
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]

        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # Re-derive 3D points from the transformed 2D points; depths unchanged.
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)

        # normalize 2D pose:
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))

        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta


class Hand3DStudioAll(Hand3DStudio):
    """`Hand-3d-Studio Dataset `_ — convenience subclass that defaults
    ``task`` to ``'all'`` (both object and no-object hands)."""
    def __init__(self, root, task='all', **kwargs):
        super(Hand3DStudioAll, self).__init__(root, task=task, **kwargs)
    note:: You need to download Human36M manually. Ensure that there exist following files in the `root` directory
        before you using this class.

        ::

            annotations/
                Human36M_subject11_joint_3d.json
                ...
            images/

    .. note:: We found that the original Human3.6M image is in high resolution while most part in an image is
        background, thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands)
        to speed up training. In `root`, there will exist following files after crop.

        ::

            Human36M_crop/
                annotations/
                    keypoints2d_11.json
                    ...
    """
    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        assert split in ['train', 'test', 'all']
        self.split = split

        samples = []
        # Standard Human3.6M protocol: subjects 1/5/6/7/8 for training,
        # 9/11 for evaluation.
        if self.split == 'train':
            parts = [1, 5, 6, 7, 8]
        elif self.split == 'test':
            parts = [9, 11]
        else:
            parts = [1, 5, 6, 7, 8, 9, 11]

        for part in parts:
            annotation_file = os.path.join(root, 'annotations/keypoints2d_{}.json'.format(part))
            # Lazily build the cropped images + 2D annotations on first use.
            if not os.path.exists(annotation_file):
                self.preprocess(part, root)
            print("loading", annotation_file)
            with open(annotation_file) as f:
                samples.extend(json.load(f))

        # decrease the number of test samples to decrease the time spent on test
        # NOTE(review): random.choices samples WITH replacement, so the test
        # set may contain duplicates; `random` is expected to come in via the
        # wildcard imports above — confirm.
        random.seed(42)
        if self.split == 'test':
            samples = random.choices(samples, k=3200)

        super(Human36M, self).__init__(root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` where ``target`` is
        the 2D-heatmap tensor and ``meta`` carries 2D and normalized 3D poses."""
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, "crop_images", image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]

        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # Re-derive 3D points from transformed 2D points; depths unchanged.
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)

        # normalize 2D pose:
        visible = np.ones((self.num_keypoints, ), dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))

        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
        }
        return image, target, target_weight, meta

    def preprocess(self, part, root):
        """Convert raw Human3.6M annotations for one subject (`part`) into
        cropped images plus a ``keypoints2d_{part}.json`` file.

        For every 5th frame: select/reorder joints to the 16-joint body format,
        project 3D joints to 2D with the camera parameters, crop a square
        region 1.5x the person's bounding box, resize to 512x512 and rescale
        the 2D keypoints and intrinsics accordingly.
        """
        # Mapping from Human3.6M joint order to the Body16 order; index 7 is a
        # placeholder that is overwritten with the shoulder midpoint below.
        body_index = [3, 2, 1, 4, 5, 6, 0, 11, 8, 10, 16, 15, 14, 11, 12, 13]
        image_size = 512
        print("preprocessing part", part)
        camera_json = os.path.join(root, "annotations", "Human36M_subject{}_camera.json".format(part))
        data_json = os.path.join(root, "annotations", "Human36M_subject{}_data.json".format(part))
        joint_3d_json = os.path.join(root, "annotations", "Human36M_subject{}_joint_3d.json".format(part))
        with open(camera_json, "r") as f:
            cameras = json.load(f)
        with open(data_json, "r") as f:
            data = json.load(f)
            images = data['images']
        with open(joint_3d_json, "r") as f:
            joints_3d = json.load(f)
        data = []
        for i, image_data in enumerate(tqdm.tqdm(images)):
            # downsample
            if i % 5 == 0:
                keypoint3d = np.array(joints_3d[str(image_data["action_idx"])][str(image_data["subaction_idx"])][
                                          str(image_data["frame_idx"])])
                keypoint3d = keypoint3d[body_index, :]
                # Thorax (index 7) = midpoint of the two shoulders.
                keypoint3d[7, :] = 0.5 * (keypoint3d[12, :] + keypoint3d[13, :])
                camera = cameras[str(image_data["cam_idx"])]
                R, T = np.array(camera["R"]), np.array(camera['t'])[:, np.newaxis]
                extrinsic_matrix = np.concatenate([R, T], axis=1)
                # World -> camera coordinates.
                keypoint3d_camera = np.matmul(extrinsic_matrix, np.hstack(
                    (keypoint3d, np.ones((keypoint3d.shape[0], 1)))).T)  # (3 x NUM_KEYPOINTS)
                Z_c = keypoint3d_camera[2:3, :]  # 1 x NUM_KEYPOINTS
                f, c = np.array(camera["f"]), np.array(camera['c'])
                intrinsic_matrix = np.zeros((3, 3))
                intrinsic_matrix[0, 0] = f[0]
                intrinsic_matrix[1, 1] = f[1]
                intrinsic_matrix[0, 2] = c[0]
                intrinsic_matrix[1, 2] = c[1]
                intrinsic_matrix[2, 2] = 1
                # Camera coordinates -> image plane (perspective divide by Z).
                keypoint2d = np.matmul(intrinsic_matrix, keypoint3d_camera)  # (3 x NUM_KEYPOINTS)
                keypoint2d = keypoint2d[0: 2, :] / Z_c
                keypoint2d = keypoint2d.T

                src_image_path = os.path.join(root, "images", image_data['file_name'])
                tgt_image_path = os.path.join(root, "crop_images", image_data['file_name'])
                os.makedirs(os.path.dirname(tgt_image_path), exist_ok=True)
                image = Image.open(src_image_path)
                bounding_box = get_bounding_box(keypoint2d)
                w, h = image.size
                left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
                image, keypoint2d = crop(image, upper, left, lower-upper+1, right-left+1, keypoint2d)
                Z_c = Z_c.T
                # Calculate XYZ from uvz (back-project the cropped 2D points).
                uv1 = np.concatenate([np.copy(keypoint2d), np.ones((16, 1))], axis=1)  # NUM_KEYPOINTS x 3
                uv1 = uv1 * Z_c  # NUM_KEYPOINTS x 3
                keypoint3d_camera = np.matmul(np.linalg.inv(intrinsic_matrix), uv1.T).T

                # resize image will change camera intrinsic matrix
                w, h = image.size
                image = image.resize((image_size, image_size))
                image.save(tgt_image_path)
                zoom_factor = float(w) / float(image_size)
                keypoint2d /= zoom_factor
                intrinsic_matrix[0, 0] /= zoom_factor
                intrinsic_matrix[1, 1] /= zoom_factor
                intrinsic_matrix[0, 2] /= zoom_factor
                intrinsic_matrix[1, 2] /= zoom_factor

                data.append({
                    "name": image_data['file_name'],
                    'keypoint2d': keypoint2d.tolist(),
                    'keypoint3d': keypoint3d_camera.tolist(),
                    'intrinsic_matrix': intrinsic_matrix.tolist(),
                })
        with open(os.path.join(root, "annotations", "keypoints2d_{}.json".format(part)), "w") as f:
            json.dump(data, f)
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from abc import ABC
import numpy as np
from torch.utils.data.dataset import Dataset
from webcolors import name_to_rgb
import cv2


class KeypointDataset(Dataset, ABC):
    """A generic dataset class for image keypoint detection

    Args:
        root (str): Root directory of dataset
        num_keypoints (int): Number of keypoints
        samples (list): list of data
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2
        keypoints_group (dict): a dict that stores the index of different types of keypoints
        colored_skeleton (dict): a dict that stores the index and color of different skeleton
    """
    def __init__(self, root, num_keypoints, samples, transforms=None, image_size=(256, 256), heatmap_size=(64, 64),
                 sigma=2, keypoints_group=None, colored_skeleton=None):
        self.root = root
        self.num_keypoints = num_keypoints
        self.samples = samples
        self.transforms = transforms
        self.image_size = image_size
        self.heatmap_size = heatmap_size
        self.sigma = sigma
        self.keypoints_group = keypoints_group
        self.colored_skeleton = colored_skeleton

    def __len__(self):
        return len(self.samples)

    def visualize(self, image, keypoints, filename):
        """Visualize an image with its keypoints, and store the result into a file

        Args:
            image (PIL.Image):
            keypoints (torch.Tensor): keypoints in shape K x 2
            filename (str): the name of file to store
        """
        assert self.colored_skeleton is not None

        # PIL gives RGB; OpenCV drawing and imwrite operate on BGR arrays.
        image = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()
        for (_, (line, color)) in self.colored_skeleton.items():
            # Fix: webcolors.name_to_rgb returns (R, G, B) but the image is
            # BGR, so the tuple must be reversed or red/blue get swapped.
            line_color = tuple(reversed(name_to_rgb(color)))
            for i in range(len(line) - 1):
                start, end = keypoints[line[i]], keypoints[line[i + 1]]
                cv2.line(image, (int(start[0]), int(start[1])), (int(end[0]), int(end[1])), color=line_color,
                         thickness=3)
        for keypoint in keypoints:
            # Black is symmetric, but reverse anyway for consistency.
            cv2.circle(image, (int(keypoint[0]), int(keypoint[1])), 3, tuple(reversed(name_to_rgb('black'))), 1)
        cv2.imwrite(filename, image)

    def group_accuracy(self, accuracies):
        """ Group the accuracy of K keypoints into different kinds.

        Args:
            accuracies (list): accuracy of the K keypoints

        Returns:
            accuracy of ``N=len(keypoints_group)`` kinds of keypoints
        """
        grouped_accuracies = dict()
        for name, keypoints in self.keypoints_group.items():
            # Average the per-keypoint accuracies belonging to this group.
            grouped_accuracies[name] = sum([accuracies[idx] for idx in keypoints]) / len(keypoints)
        return grouped_accuracies


class Body16KeypointDataset(KeypointDataset, ABC):
    """
    Dataset with 16 body keypoints.
    """
    # TODO: add image
    head = (9,)
    shoulder = (12, 13)
    elbow = (11, 14)
    wrist = (10, 15)
    hip = (2, 3)
    knee = (1, 4)
    ankle = (0, 5)
    all = (12, 13, 11, 14, 10, 15, 2, 3, 1, 4, 0, 5)
    right_leg = (0, 1, 2, 8)
    left_leg = (5, 4, 3, 8)
    backbone = (8, 9)
    right_arm = (10, 11, 12, 8)
    left_arm = (15, 14, 13, 8)

    def __init__(self, root, samples, **kwargs):
        colored_skeleton = {
            "right_leg": (self.right_leg, 'yellow'),
            "left_leg": (self.left_leg, 'green'),
            "backbone": (self.backbone, 'blue'),
            "right_arm": (self.right_arm, 'purple'),
            "left_arm": (self.left_arm, 'red'),
        }
        keypoints_group = {
            "head": self.head,
            "shoulder": self.shoulder,
            "elbow": self.elbow,
            "wrist": self.wrist,
            "hip": self.hip,
            "knee": self.knee,
            "ankle": self.ankle,
            "all": self.all
        }
        super(Body16KeypointDataset, self).__init__(root, 16, samples, keypoints_group=keypoints_group,
                                                    colored_skeleton=colored_skeleton, **kwargs)


class Hand21KeypointDataset(KeypointDataset, ABC):
    """
    Dataset with 21 hand keypoints.
    """
    # TODO: add image
    MCP = (1, 5, 9, 13, 17)
    PIP = (2, 6, 10, 14, 18)
    DIP = (3, 7, 11, 15, 19)
    fingertip = (4, 8, 12, 16, 20)
    all = tuple(range(21))
    thumb = (0, 1, 2, 3, 4)
    index_finger = (0, 5, 6, 7, 8)
    middle_finger = (0, 9, 10, 11, 12)
    ring_finger = (0, 13, 14, 15, 16)
    little_finger = (0, 17, 18, 19, 20)

    def __init__(self, root, samples, **kwargs):
        colored_skeleton = {
            "thumb": (self.thumb, 'yellow'),
            "index_finger": (self.index_finger, 'green'),
            "middle_finger": (self.middle_finger, 'blue'),
            "ring_finger": (self.ring_finger, 'purple'),
            "little_finger": (self.little_finger, 'red'),
        }
        keypoints_group = {
            "MCP": self.MCP,
            "PIP": self.PIP,
            "DIP": self.DIP,
            "fingertip": self.fingertip,
            "all": self.all
        }
        super(Hand21KeypointDataset, self).__init__(root, 21, samples, keypoints_group=keypoints_group,
                                                    colored_skeleton=colored_skeleton, **kwargs)
:: lsp/ images/ joints.mat .. note:: LSP is only used for target domain. Due to the small dataset size, the whole dataset is used no matter what ``split`` is. Also, the transform is fixed. """ def __init__(self, root, split='train', task='all', download=True, image_size=(256, 256), transforms=None, **kwargs): if download: download_data(root, "images", "lsp_dataset.zip", "https://cloud.tsinghua.edu.cn/f/46ea73c89abc46bfb125/?dl=1") else: check_exits(root, "lsp") assert split in ['train', 'test', 'all'] self.split = split samples = [] annotations = scio.loadmat(os.path.join(root, "joints.mat"))['joints'].transpose((2, 1, 0)) for i in range(0, 2000): image = "im{0:04d}.jpg".format(i+1) annotation = annotations[i] samples.append((image, annotation)) self.joints_index = (0, 1, 2, 3, 4, 5, 13, 13, 12, 13, 6, 7, 8, 9, 10, 11) self.visible = np.array([1.] * 6 + [0, 0] + [1.] * 8, dtype=np.float32) normalize = Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) transforms = Compose([ ResizePad(image_size[0]), ToTensor(), normalize ]) super(LSP, self).__init__(root, samples, transforms=transforms, image_size=image_size, **kwargs) def __getitem__(self, index): sample = self.samples[index] image_name = sample[0] image = Image.open(os.path.join(self.root, "images", image_name)) keypoint2d = sample[1][self.joints_index, :2] image, data = self.transforms(image, keypoint2d=keypoint2d) keypoint2d = data['keypoint2d'] visible = self.visible * (1-sample[1][self.joints_index, 2]) visible = visible[:, np.newaxis] # 2D heatmap target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size) target = torch.from_numpy(target) target_weight = torch.from_numpy(target_weight) meta = { 'image': image_name, 'keypoint2d': keypoint2d, # (NUM_KEYPOINTS x 2) 'keypoint3d': np.zeros((self.num_keypoints, 3)).astype(keypoint2d.dtype), # (NUM_KEYPOINTS x 3) } return image, target, target_weight, meta ================================================ FILE: 
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch
import os
import pickle

from .._util import download as download_data, check_exits
from ...transforms.keypoint_detection import *
from .keypoint_dataset import Hand21KeypointDataset
from .util import *


class RenderedHandPose(Hand21KeypointDataset):
    """`Rendered Handpose Dataset `_

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, ``test``, or ``all``.
        task (str, optional): Placeholder.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transforms (callable, optional): A function/transform that takes in a dict (which contains PIL image and
            its labels) and returns a transformed version. E.g, :class:`~tllib.vision.transforms.keypoint_detection.Resize`.
        image_size (tuple): (width, height) of the image. Default: (256, 256)
        heatmap_size (tuple): (width, height) of the heatmap. Default: (64, 64)
        sigma (int): sigma parameter when generate the heatmap. Default: 2

    .. note:: In `root`, there will exist following files after downloading.
        ::

            RHD_published_v2/
                training/
                evaluation/
    """
    def __init__(self, root, split='train', task='all', download=True, **kwargs):
        if download:
            download_data(root, "RHD_published_v2", "RHD_v1-1.zip",
                          "https://lmb.informatik.uni-freiburg.de/data/RenderedHandpose/RHD_v1-1.zip")
        else:
            check_exits(root, "RHD_published_v2")

        root = os.path.join(root, "RHD_published_v2")
        assert split in ['train', 'test', 'all']
        self.split = split
        if split == 'all':
            samples = self.get_samples(root, 'train') + self.get_samples(root, 'test')
        else:
            samples = self.get_samples(root, split)
        super(RenderedHandPose, self).__init__(
            root, samples, **kwargs)

    def __getitem__(self, index):
        """Return ``(image, target, target_weight, meta)`` for one hand crop,
        with per-keypoint visibility taken from the annotations."""
        sample = self.samples[index]
        image_name = sample['name']
        image_path = os.path.join(self.root, image_name)
        image = Image.open(image_path)
        keypoint3d_camera = np.array(sample['keypoint3d'])  # NUM_KEYPOINTS x 3
        keypoint2d = np.array(sample['keypoint2d'])  # NUM_KEYPOINTS x 2
        intrinsic_matrix = np.array(sample['intrinsic_matrix'])
        Zc = keypoint3d_camera[:, 2]

        # Crop the images such that the hand is at the center of the image
        # The images will be 1.5 times larger than the hand
        # The crop process will change Xc and Yc, leaving Zc with no changes
        bounding_box = get_bounding_box(keypoint2d)
        w, h = image.size
        left, upper, right, lower = scale_box(bounding_box, w, h, 1.5)
        image, keypoint2d = crop(image, upper, left, lower - upper, right - left, keypoint2d)

        # Change all hands to right hands
        # NOTE(review): the mirror fires when 'left' is False (right-hand
        # samples) — this looks inverted relative to the comment, but matches
        # FreiHand's convention in this package; confirm before changing.
        if sample['left'] is False:
            image, keypoint2d = hflip(image, keypoint2d)

        image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix)
        keypoint2d = data['keypoint2d']
        intrinsic_matrix = data['intrinsic_matrix']
        # Re-derive 3D points from the transformed 2D points; depths unchanged.
        keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc)

        # normalize 2D pose:
        visible = np.array(sample['visible'], dtype=np.float32)
        visible = visible[:, np.newaxis]
        # 2D heatmap
        target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size)
        target = torch.from_numpy(target)
        target_weight = torch.from_numpy(target_weight)

        # normalize 3D pose:
        # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system
        # and make distance between wrist and middle finger MCP joint to be of length 1
        keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :]
        keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2))
        z = keypoint3d_n[:, 2]

        meta = {
            'image': image_name,
            'keypoint2d': keypoint2d,  # (NUM_KEYPOINTS x 2)
            'keypoint3d': keypoint3d_n,  # (NUM_KEYPOINTS x 3)
            'z': z,
        }
        return image, target, target_weight, meta

    def get_samples(self, root, task, min_size=64):
        """Build the sample list for one RHD split.

        Each annotated image contains both a left and a right hand (42
        keypoints); each hand becomes a separate sample if its box is larger
        than `min_size`, at least 17 of its 21 keypoints are visible, and it
        overlaps the other hand's box by less than 30%.
        """
        if task == 'train':
            set = 'training'
        else:
            set = 'evaluation'
        # load annotations of this set
        with open(os.path.join(root, set, 'anno_%s.pickle' % set), 'rb') as fi:
            anno_all = pickle.load(fi)

        samples = []
        # Reorder RHD's per-hand joint layout to this package's 21-keypoint
        # convention; right-hand indices follow the left hand's 21 entries.
        left_hand_index = [0, 4, 3, 2, 1, 8, 7, 6, 5, 12, 11, 10, 9, 16, 15, 14, 13, 20, 19, 18, 17]
        right_hand_index = [i+21 for i in left_hand_index]
        for sample_id, anno in anno_all.items():
            image_name = os.path.join(set, 'color', '%.5d.png' % sample_id)
            mask_name = os.path.join(set, 'mask', '%.5d.png' % sample_id)
            keypoint2d = anno['uv_vis'][:, :2]
            keypoint3d = anno['xyz']
            intrinsic_matrix = anno['K']
            visible = anno['uv_vis'][:, 2]

            left_hand_keypoint2d = keypoint2d[left_hand_index]  # NUM_KEYPOINTS x 2
            left_box = get_bounding_box(left_hand_keypoint2d)
            right_hand_keypoint2d = keypoint2d[right_hand_index]  # NUM_KEYPOINTS x 2
            right_box = get_bounding_box(right_hand_keypoint2d)

            w, h = 320, 320
            scaled_left_box = scale_box(left_box, w, h, 1.5)
            left, upper, right, lower = scaled_left_box
            size = max(right - left, lower - upper)
            # Keep the left hand if it is big enough, mostly visible, and not
            # heavily overlapped by the right hand.
            if size > min_size and np.sum(visible[left_hand_index]) > 16 and \
                    area(*intersection(scaled_left_box, right_box)) / area(*scaled_left_box) < 0.3:
                sample = {
                    'name': image_name,
                    'mask_name': mask_name,
                    'keypoint2d': left_hand_keypoint2d,
                    'visible': visible[left_hand_index],
                    'keypoint3d': keypoint3d[left_hand_index],
                    'intrinsic_matrix': intrinsic_matrix,
                    'left': True
                }
                samples.append(sample)

            scaled_right_box = scale_box(right_box, w, h, 1.5)
            left, upper, right, lower = scaled_right_box
            size = max(right - left, lower - upper)
            # Same filtering for the right hand.
            if size > min_size and np.sum(visible[right_hand_index]) > 16 and \
                    area(*intersection(scaled_right_box, left_box)) / area(*scaled_right_box) < 0.3:
                sample = {
                    'name': image_name,
                    'mask_name': mask_name,
                    'keypoint2d': right_hand_keypoint2d,
                    'visible': visible[right_hand_index],
                    'keypoint3d': keypoint3d[right_hand_index],
                    'intrinsic_matrix': intrinsic_matrix,
                    'left': False
                }
                samples.append(sample)
        return samples
note:: We found that the original Surreal image is in high resolution while most part in an image is background, thus we crop the image and keep only the surrounding area of hands (1.5x bigger than hands) to speed up training. .. note:: In `root`, there will exist following files after downloading. :: train/ test/ val/ """ def __init__(self, root, split='train', task='all', download=True, **kwargs): assert split in ['train', 'test', 'val'] self.split = split if download: download_data(root, "train/run0", "train0.tgz", "https://cloud.tsinghua.edu.cn/f/b13604f06ff1445c830a/?dl=1") download_data(root, "train/run1", "train1.tgz", "https://cloud.tsinghua.edu.cn/f/919aefe2de3541c3b940/?dl=1") download_data(root, "train/run1", "train2.tgz", "https://cloud.tsinghua.edu.cn/f/34864760ad4945b9bcd6/?dl=1") download_data(root, "val", "val.tgz", "https://cloud.tsinghua.edu.cn/f/16b20f2e76684f848dc1/?dl=1") download_data(root, "test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/36c72d86e43540e0a913/?dl=1") else: check_exits(root, "train/run0") check_exits(root, "train/run1") check_exits(root, "train/run2") check_exits(root, "val") check_exits(root, "test") all_samples = [] for part in [0, 1, 2]: annotation_file = os.path.join(root, split, 'run{}.json'.format(part)) print("loading", annotation_file) with open(annotation_file) as f: samples = json.load(f) for sample in samples: sample["image_path"] = os.path.join(root, self.split, 'run{}'.format(part), sample['name']) all_samples.extend(samples) random.seed(42) random.shuffle(all_samples) samples_len = len(all_samples) samples_split = min(int(samples_len * 0.2), 3200) if self.split == 'train': all_samples = all_samples[samples_split:] elif self.split == 'test': all_samples = all_samples[:samples_split] self.joints_index = (7, 4, 1, 2, 5, 8, 0, 9, 12, 15, 20, 18, 13, 14, 19, 21) super(SURREAL, self).__init__(root, all_samples, **kwargs) def __getitem__(self, index): sample = self.samples[index] image_name = sample['name'] 
image_path = sample['image_path'] image = Image.open(image_path) keypoint3d_camera = np.array(sample['keypoint3d'])[self.joints_index, :] # NUM_KEYPOINTS x 3 keypoint2d = np.array(sample['keypoint2d'])[self.joints_index, :] # NUM_KEYPOINTS x 2 intrinsic_matrix = np.array(sample['intrinsic_matrix']) Zc = keypoint3d_camera[:, 2] image, data = self.transforms(image, keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix) keypoint2d = data['keypoint2d'] intrinsic_matrix = data['intrinsic_matrix'] keypoint3d_camera = keypoint2d_to_3d(keypoint2d, intrinsic_matrix, Zc) # noramlize 2D pose: visible = np.array([1.] * 16, dtype=np.float32) visible = visible[:, np.newaxis] # 2D heatmap target, target_weight = generate_target(keypoint2d, visible, self.heatmap_size, self.sigma, self.image_size) target = torch.from_numpy(target) target_weight = torch.from_numpy(target_weight) # normalize 3D pose: # put middle finger metacarpophalangeal (MCP) joint in the center of the coordinate system # and make distance between wrist and middle finger MCP joint to be of length 1 keypoint3d_n = keypoint3d_camera - keypoint3d_camera[9:10, :] keypoint3d_n = keypoint3d_n / np.sqrt(np.sum(keypoint3d_n[0, :] ** 2)) meta = { 'image': image_name, 'keypoint2d': keypoint2d, # (NUM_KEYPOINTS x 2) 'keypoint3d': keypoint3d_n, # (NUM_KEYPOINTS x 3) } return image, target, target_weight, meta def __len__(self): return len(self.samples) ================================================ FILE: tllib/vision/datasets/keypoint_detection/util.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import numpy as np import cv2 def generate_target(joints, joints_vis, heatmap_size, sigma, image_size): """Generate heatamap for joints. 
def generate_target(joints, joints_vis, heatmap_size, sigma, image_size):
    """Render one Gaussian heatmap per joint.

    Args:
        joints: (K, 2) joint coordinates in image space.
        joints_vis: (K, 1) visibility flags (1 = visible, 0 = hidden).
        heatmap_size: (W, H) of the produced heatmaps.
        sigma: standard deviation of the Gaussian blob.
        image_size: (W, H) of the source image.

    Returns:
        target: (K, H, W) float32 heatmaps with a peak value of 1 at each joint.
        target_weight: (K, 1) float32; zeroed for joints whose center falls
            outside the heatmap so they can be ignored by the loss.
    """
    num_joints = joints.shape[0]
    target_weight = np.ones((num_joints, 1), dtype=np.float32)
    target_weight[:, 0] = joints_vis[:, 0]
    target = np.zeros((num_joints, heatmap_size[1], heatmap_size[0]), dtype=np.float32)

    tmp_size = sigma * 3
    image_size = np.array(image_size)
    heatmap_size = np.array(heatmap_size)

    for k in range(num_joints):
        # Map the joint from image coordinates onto the heatmap grid.
        feat_stride = image_size / heatmap_size
        mu_x = int(joints[k][0] / feat_stride[0] + 0.5)
        mu_y = int(joints[k][1] / feat_stride[1] + 0.5)
        # Bounding box of the Gaussian patch around the rounded center.
        ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
        br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
        if mu_x >= heatmap_size[0] or mu_y >= heatmap_size[1] \
                or mu_x < 0 or mu_y < 0:
            # Center lies outside the heatmap: mark the joint as unused.
            target_weight[k] = 0
            continue

        # Un-normalized Gaussian whose center value equals 1.
        size = 2 * tmp_size + 1
        xs = np.arange(0, size, 1, np.float32)
        ys = xs[:, np.newaxis]
        x0 = y0 = size // 2
        g = np.exp(-((xs - x0) ** 2 + (ys - y0) ** 2) / (2 * sigma ** 2))

        # Clip the patch against the heatmap borders.
        g_x = max(0, -ul[0]), min(br[0], heatmap_size[0]) - ul[0]
        g_y = max(0, -ul[1]), min(br[1], heatmap_size[1]) - ul[1]
        img_x = max(0, ul[0]), min(br[0], heatmap_size[0])
        img_y = max(0, ul[1]), min(br[1], heatmap_size[1])

        if target_weight[k] > 0.5:
            target[k][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
                g[g_y[0]:g_y[1], g_x[0]:g_x[1]]

    return target, target_weight


def keypoint2d_to_3d(keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, Zc: np.ndarray):
    """Back-project 2D pixel keypoints to 3D camera coordinates at depths ``Zc``."""
    ones = np.ones((keypoint2d.shape[0], 1))
    uv1 = np.concatenate([np.copy(keypoint2d), ones], axis=1).T * Zc  # 3 x NUM_KEYPOINTS
    xyz = np.matmul(np.linalg.inv(intrinsic_matrix), uv1).T  # NUM_KEYPOINTS x 3
    return xyz


def keypoint3d_to_2d(keypoint3d: np.ndarray, intrinsic_matrix: np.ndarray):
    """Project 3D camera-frame keypoints onto the image plane."""
    projected = np.matmul(intrinsic_matrix, keypoint3d.T).T  # NUM_KEYPOINTS x 3
    keypoint2d = projected[:, :2] / projected[:, 2:3]  # NUM_KEYPOINTS x 2
    return keypoint2d


def scale_box(box, image_width, image_height, scale):
    """Turn ``box`` into a square box with side ``scale * max(w, h)``
    (capped at the image's shorter side), shifted to stay inside the image.
    """
    left, upper, right, lower = box
    center_x, center_y = (left + right) / 2, (upper + lower) / 2
    w, h = right - left, lower - upper
    side_len = min(round(scale * max(w, h)), min(image_width, image_height))
    left = round(center_x - side_len / 2)
    right = left + side_len - 1
    upper = round(center_y - side_len / 2)
    lower = upper + side_len - 1
    # Push the square back inside the image when it sticks out of a border.
    if left < 0:
        left = 0
        right = side_len - 1
    if right >= image_width:
        right = image_width - 1
        left = image_width - side_len
    if upper < 0:
        upper = 0
        lower = side_len - 1
    if lower >= image_height:
        lower = image_height - 1
        upper = image_height - side_len
    return left, upper, right, lower


def get_bounding_box(keypoint2d: np.array):
    """Axis-aligned bounding box (left, upper, right, lower) of the keypoints."""
    xs, ys = keypoint2d[:, 0], keypoint2d[:, 1]
    return xs.min(), ys.min(), xs.max(), ys.max()


def visualize_heatmap(image, heatmaps, filename):
    """Blend each heatmap over the (resized) image and write one file per joint.

    ``filename`` must contain a ``{}`` placeholder for the joint index.
    """
    bgr = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR).copy()
    H, W = heatmaps.shape[1], heatmaps.shape[2]
    base = cv2.resize(bgr, (int(W), int(H)))
    heatmaps = heatmaps.mul(255).clamp(0, 255).byte().cpu().numpy()
    for k in range(heatmaps.shape[0]):
        colored = cv2.applyColorMap(heatmaps[k], cv2.COLORMAP_JET)
        blended = colored * 0.7 + base * 0.3
        cv2.imwrite(filename.format(k), blended)


def area(left, upper, right, lower):
    """Pixel area of an inclusive box; degenerate boxes yield 0."""
    return max(right - left + 1, 0) * max(lower - upper + 1, 0)


def intersection(box_a, box_b):
    """Intersection of two inclusive boxes (possibly degenerate)."""
    left_a, upper_a, right_a, lower_a = box_a
    left_b, upper_b, right_b, lower_b = box_b
    return (max(left_a, left_b), max(upper_a, upper_b),
            min(right_a, right_b), min(lower_a, lower_b))
# ======== FILE: tllib/vision/datasets/object_detection/__init__.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import numpy as np
import os
import xml.etree.ElementTree as ET

from detectron2.data import (
    MetadataCatalog,
    DatasetCatalog,
)
from detectron2.utils.file_io import PathManager
from detectron2.structures import BoxMode

from tllib.vision.datasets._util import download as download_dataset


def parse_root_and_file_name(path):
    """Split ``path`` into (parent directory, final component).

    A bare file name gets ``'.'`` as its parent.
    """
    dataset_root, _, file_name = path.rpartition('/')
    if dataset_root == '':
        dataset_root = '.'
    return dataset_root, file_name


class VOCBase:
    """Pascal-VOC-style detection dataset that registers itself in detectron2's
    catalogs under the name ``<root>_<split>`` and optionally downloads its archive.
    """
    class_names = (
        "aeroplane", "bicycle", "bird", "boat", "bottle",
        "bus", "car", "cat", "chair", "cow",
        "diningtable", "dog", "horse", "motorbike", "person",
        "pottedplant", "sheep", "sofa", "train", "tvmonitor",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.jpg', download=True):
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"
        if download:
            dataset_root, file_name = parse_root_and_file_name(root)
            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)


class VOC2007(VOCBase):
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007, self).__init__(root)


class VOC2012(VOCBase):
    archive_name = 'VOC2012.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'

    def __init__(self, root):
        super(VOC2012, self).__init__(root, year=2012)


class VOC2007Test(VOCBase):
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007Test, self).__init__(root, year=2007, split='test')


class Clipart(VOCBase):
    archive_name = 'clipart.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/c853a66786e2416a8f18/?dl=1'


class VOCPartialBase:
    """VOC-style dataset restricted to the six categories shared with the
    watercolor/comic artistic domains."""
    class_names = (
        "bicycle", "bird", "car", "cat", "dog", "person",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.jpg', download=True):
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names, ext=ext)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"
        if download:
            dataset_root, file_name = parse_root_and_file_name(root)
            download_dataset(dataset_root, file_name, self.archive_name, self.dataset_url)


class VOC2007Partial(VOCPartialBase):
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007Partial, self).__init__(root)


class VOC2012Partial(VOCPartialBase):
    archive_name = 'VOC2012.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/a7e7ab88f727408eaf32/?dl=1'

    def __init__(self, root):
        super(VOC2012Partial, self).__init__(root, year=2012)


class VOC2007PartialTest(VOCPartialBase):
    archive_name = 'VOC2007.tgz'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/800a9495d3b74612be3f/?dl=1'

    def __init__(self, root):
        super(VOC2007PartialTest, self).__init__(root, year=2007, split='test')


class WaterColor(VOCPartialBase):
    archive_name = 'watercolor.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'

    def __init__(self, root):
        super(WaterColor, self).__init__(root, split='train')


class WaterColorTest(VOCPartialBase):
    archive_name = 'watercolor.zip'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/9f322fd8496f4766ad93/?dl=1'

    def __init__(self, root):
        super(WaterColorTest, self).__init__(root, split='test')


class Comic(VOCPartialBase):
    archive_name = 'comic.tar'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'

    def __init__(self, root):
        super(Comic, self).__init__(root, split='train')


class ComicTest(VOCPartialBase):
    archive_name = 'comic.tar'
    dataset_url = 'https://cloud.tsinghua.edu.cn/f/030d7b4b649f46589b2d/?dl=1'

    def __init__(self, root):
        super(ComicTest, self).__init__(root, split='test')


class CityscapesBase:
    """Cityscapes exported in VOC format; box coordinates are already 0-based."""
    class_names = (
        "bicycle", "bus", "car", "motorcycle",
        "person", "rider", "train", "truck",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.png'):
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names,
                                ext=ext, bbox_zero_based=True)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"


class Cityscapes(CityscapesBase):
    def __init__(self, root):
        super(Cityscapes, self).__init__(root, split="trainval")


class CityscapesTest(CityscapesBase):
    def __init__(self, root):
        super(CityscapesTest, self).__init__(root, split='test')


class FoggyCityscapes(Cityscapes):
    pass


class FoggyCityscapesTest(CityscapesTest):
    pass


class CityscapesCarBase:
    """Single-category ("car") VOC-style dataset."""
    class_names = (
        "car",
    )

    def __init__(self, root, split="trainval", year=2007, ext='.png', bbox_zero_based=True):
        self.name = "{}_{}".format(root, split).replace(os.path.sep, "_")
        if self.name not in MetadataCatalog.keys():
            register_pascal_voc(self.name, root, split, year, class_names=self.class_names,
                                ext=ext, bbox_zero_based=bbox_zero_based)
            MetadataCatalog.get(self.name).evaluator_type = "pascal_voc"


class CityscapesCar(CityscapesCarBase):
    pass


class CityscapesCarTest(CityscapesCarBase):
    def __init__(self, root):
        super(CityscapesCarTest, self).__init__(root, split='test')


class Sim10kCar(CityscapesCarBase):
    def __init__(self, root):
        super(Sim10kCar, self).__init__(root, split='trainval10k', ext='.jpg', bbox_zero_based=False)
class KITTICar(CityscapesCarBase):
    def __init__(self, root):
        super(KITTICar, self).__init__(root, split='trainval', ext='.jpg', bbox_zero_based=False)


class GTA5(CityscapesBase):
    def __init__(self, root):
        super(GTA5, self).__init__(root, split="trainval", ext='.jpg')


def load_voc_instances(dirname: str, split: str, class_names, ext='.jpg', bbox_zero_based=False):
    """
    Load Pascal VOC detection annotations to Detectron2 format.

    Args:
        dirname: Contain "Annotations", "ImageSets", "JPEGImages"
        split (str): one of "train", "test", "val", "trainval"
        class_names: list or tuple of class names
        ext (str): image file extension, e.g. '.jpg'
        bbox_zero_based (bool): whether annotated coordinates are already
            0-based; when False, 1 is subtracted from xmin/ymin (VOC convention).

    Returns:
        list[dict]: one record per image in Detectron2's standard dataset format.
    """
    with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f:
        # BUG FIX: `np.str` was a deprecated alias of the builtin `str` and was
        # removed in NumPy 1.24; `dtype=str` is the documented, equivalent form.
        fileids = np.loadtxt(f, dtype=str)

    # Needs to read many small annotation files. Makes sense at local
    annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/"))
    dicts = []
    skip_classes = set()
    for fileid in fileids:
        anno_file = os.path.join(annotation_dirname, fileid + ".xml")
        jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ext)

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)

        r = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(tree.findall("./size/height")[0].text),
            "width": int(tree.findall("./size/width")[0].text),
        }
        instances = []

        for obj in tree.findall("object"):
            cls = obj.find("name").text
            if cls not in class_names:
                skip_classes.add(cls)
                continue
            # We include "difficult" samples in training.
            # Based on limited experiments, they don't hurt accuracy.
            # difficult = int(obj.find("difficult").text)
            # if difficult == 1:
            #     continue
            bbox = obj.find("bndbox")
            bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]]
            # Original annotations are integers in the range [1, W or H].
            # Assuming they mean 1-based pixel indices (inclusive), a box with
            # annotation (xmin=1, xmax=W) covers the whole image. In coordinate
            # space this is represented by (xmin=0, xmax=W).
            if bbox_zero_based is False:
                bbox[0] -= 1.0
                bbox[1] -= 1.0
            instances.append(
                {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS}
            )
        r["annotations"] = instances
        dicts.append(r)
    print("Skip classes:", list(skip_classes))
    return dicts


def register_pascal_voc(name, dirname, split, year, class_names, **kwargs):
    """Register dataset ``name`` in detectron2's catalogs with VOC metadata."""
    DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names, **kwargs))
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names), dirname=dirname, year=year, split=split
    )


# ======== FILE: tllib/vision/datasets/office31.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from typing import Optional
import os

from .imagelist import ImageList
from ._util import download as download_data, check_exits


class Office31(ImageList):
    """Office31 Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon,
            ``'D'``: dslr and ``'W'``: webcam.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note:: In `root`, there will exist following files after downloading. ::

        amazon/
            images/
                backpack/
                    *.jpg
                ...
        dslr/
        webcam/
        image_list/
            amazon.txt
            dslr.txt
            webcam.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/2c1dd9fbcaa9455aa4ad/?dl=1"),
        ("amazon", "amazon.tgz", "https://cloud.tsinghua.edu.cn/f/ec12dfcddade43ab8101/?dl=1"),
        ("dslr", "dslr.tgz", "https://cloud.tsinghua.edu.cn/f/a41d818ae2f34da7bb32/?dl=1"),
        ("webcam", "webcam.tgz", "https://cloud.tsinghua.edu.cn/f/8a41009a166e4131adcd/?dl=1"),
    ]
    image_list = {
        "A": "image_list/amazon.txt",
        "D": "image_list/dslr.txt",
        "W": "image_list/webcam.txt"
    }
    CLASSES = ['back_pack', 'bike', 'bike_helmet', 'bookcase', 'bottle', 'calculator', 'desk_chair',
               'desk_lamp', 'desktop_computer', 'file_cabinet', 'headphones', 'keyboard',
               'laptop_computer', 'letter_tray', 'mobile_phone', 'monitor', 'mouse', 'mug',
               'paper_notebook', 'pen', 'phone', 'printer', 'projector', 'punchers', 'ring_binder',
               'ruler', 'scissors', 'speaker', 'stapler', 'tape_dispenser', 'trash_can']

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # BUG FIX: the original `map(lambda file_name, _: ..., self.download_list)`
            # handed each (dir, archive, url) tuple to a two-parameter lambda as a
            # single argument, raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        super(Office31, self).__init__(root, Office31.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        return list(cls.image_list.keys())
# ======== FILE: tllib/vision/datasets/officecaltech.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS, default_loader
from torchvision.datasets.utils import download_and_extract_archive

from ._util import check_exits


class OfficeCaltech(DatasetFolder):
    """Office+Caltech Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'A'``: amazon,
            ``'D'``: dslr, ``'W'``: webcam and ``'C'``: caltech.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note:: In `root`, there will exist following files after downloading. ::

        amazon/
            images/
                backpack/
                    *.jpg
                ...
        dslr/
        webcam/
        caltech/
        image_list/
            amazon.txt
            dslr.txt
            webcam.txt
            caltech.txt
    """
    directories = {
        "A": "amazon",
        "D": "dslr",
        "W": "webcam",
        "C": "caltech"
    }
    CLASSES = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard', 'laptop_computer',
               'monitor', 'mouse', 'mug', 'projector']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        if download:
            # All four domains ship in one archive: download once if any is missing.
            for dir in self.directories.values():
                if not os.path.exists(os.path.join(root, dir)):
                    download_and_extract_archive(url="https://cloud.tsinghua.edu.cn/f/eea518fa781a41d1b20e/?dl=1",
                                                 download_root=os.path.join(root, 'download'),
                                                 filename="office-caltech.tgz", remove_finished=False,
                                                 extract_root=root)
                    break
        else:
            # BUG FIX: the original `map(lambda dir, _: ..., self.directories.values())`
            # passed each directory name to a two-parameter lambda, raising
            # TypeError whenever download=False.
            for dir_name in self.directories.values():
                check_exits(root, dir_name)

        super(OfficeCaltech, self).__init__(
            os.path.join(root, self.directories[task]), default_loader,
            extensions=IMG_EXTENSIONS, **kwargs)
        self.classes = OfficeCaltech.CLASSES
        # BUG FIX: the original comprehension iterated `for cls in clss`, i.e. over
        # the *characters* of every class-name string, producing a bogus mapping
        # from single characters to indices. Map each class name to its index.
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    @property
    def num_classes(self):
        """Number of classes"""
        return len(self.classes)

    @classmethod
    def domains(cls):
        return list(cls.directories.keys())
# ======== FILE: tllib/vision/datasets/officehome.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import os
from typing import Optional
from .imagelist import ImageList
from ._util import download as download_data, check_exits


class OfficeHome(ImageList):
    """OfficeHome Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'Ar'``: Art,
            ``'Cl'``: Clipart, ``'Pr'``: Product and ``'Rw'``: Real_World.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note:: In `root`, there will exist following files after downloading. ::

        Art/
            Alarm_Clock/*.jpg
            ...
        Clipart/
        Product/
        Real_World/
        image_list/
            Art.txt
            Clipart.txt
            Product.txt
            Real_World.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1b0171a188944313b1f5/?dl=1"),
        ("Art", "Art.tgz", "https://cloud.tsinghua.edu.cn/f/6a006656b9a14567ade2/?dl=1"),
        ("Clipart", "Clipart.tgz", "https://cloud.tsinghua.edu.cn/f/ae88aa31d2d7411dad79/?dl=1"),
        ("Product", "Product.tgz", "https://cloud.tsinghua.edu.cn/f/f219b0ff35e142b3ab48/?dl=1"),
        ("Real_World", "Real_World.tgz", "https://cloud.tsinghua.edu.cn/f/6c19f3f15bb24ed3951a/?dl=1")
    ]
    image_list = {
        "Ar": "image_list/Art.txt",
        "Cl": "image_list/Clipart.txt",
        "Pr": "image_list/Product.txt",
        "Rw": "image_list/Real_World.txt",
    }
    CLASSES = ['Drill', 'Exit_Sign', 'Bottle', 'Glasses', 'Computer', 'File_Cabinet', 'Shelf',
               'Toys', 'Sink', 'Laptop', 'Kettle', 'Folder', 'Keyboard', 'Flipflops', 'Pencil',
               'Bed', 'Hammer', 'ToothBrush', 'Couch', 'Bike', 'Postit_Notes', 'Mug', 'Webcam',
               'Desk_Lamp', 'Telephone', 'Helmet', 'Mouse', 'Pen', 'Monitor', 'Mop', 'Sneakers',
               'Notebook', 'Backpack', 'Alarm_Clock', 'Push_Pin', 'Paper_Clip', 'Batteries',
               'Radio', 'Fan', 'Ruler', 'Pan', 'Screwdriver', 'Trash_Can', 'Printer', 'Speaker',
               'Eraser', 'Bucket', 'Chair', 'Calendar', 'Calculator', 'Flowers', 'Lamp_Shade',
               'Spoon', 'Candles', 'Clipboards', 'Scissors', 'TV', 'Curtains', 'Fork', 'Soda',
               'Table', 'Knives', 'Oven', 'Refrigerator', 'Marker']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # BUG FIX: the original `map(lambda file_name, _: ..., self.download_list)`
            # handed each tuple to a two-parameter lambda as a single argument,
            # raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        super(OfficeHome, self).__init__(root, OfficeHome.CLASSES, data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        return list(cls.image_list.keys())
# ======== FILE: tllib/vision/datasets/openset/__init__.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from typing import Optional, ClassVar, Sequence
from copy import deepcopy

__all__ = ['Office31', 'OfficeHome', "VisDA2017"]


def open_set(dataset_class: ClassVar, public_classes: Sequence[str],
             private_classes: Optional[Sequence[str]] = ()) -> ClassVar:
    """
    Convert a dataset into its open-set version.

    Samples whose category belongs to `private_classes` are relabeled as "unknown";
    samples in `public_classes` keep their category; samples in neither set are
    dropped. (DOC FIX: the original text stated the inverse condition.)

    Be aware that `open_set` will change the label number of each category.

    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be open-set.
        public_classes (sequence[str]): A sequence of which categories need to be kept in
            the open-set dataset. Each element of `public_classes` must belong to the
            `classes` list of `dataset_class`.
        private_classes (sequence[str], optional): A sequence of which categories need to
            be marked as "unknown" in the open-set dataset. Each element of
            `private_classes` must belong to the `classes` list of `dataset_class`.
            Default: ().

    Examples::

        >>> public_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> private_classes = ['laptop_computer', 'monitor', 'mouse', 'mug', 'projector']
        >>> # create a open-set dataset class which has classes
        >>> # 'back_pack', 'bike', 'calculator', 'headphones', 'keyboard' and 'unknown'.
        >>> OpenSetOffice31 = open_set(Office31, public_classes, private_classes)
        >>> # create an instance of the open-set dataset
        >>> # (DOC FIX: the original example instantiated the inner name
        >>> # `OpenSetDataset`, which is not visible to the caller)
        >>> dataset = OpenSetOffice31(root="data/office31", task="A")
    """
    if not (issubclass(dataset_class, ImageList)):
        raise Exception("Only subclass of ImageList can be openset")

    class OpenSetDataset(dataset_class):
        def __init__(self, **kwargs):
            super(OpenSetDataset, self).__init__(**kwargs)
            samples = []
            # Public classes keep their position; everything private maps to the
            # single trailing "unknown" label.
            all_classes = list(deepcopy(public_classes)) + ["unknown"]
            for (path, label) in self.samples:
                class_name = self.classes[label]
                if class_name in public_classes:
                    samples.append((path, all_classes.index(class_name)))
                elif class_name in private_classes:
                    samples.append((path, all_classes.index("unknown")))
            self.samples = samples
            self.classes = all_classes
            self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

    return OpenSetDataset


def default_open_set(dataset_class: ClassVar, source: bool) -> ClassVar:
    """
    Default open-set split used in some papers.

    Args:
        dataset_class (class): one of :class:`~tllib.vision.datasets.office31.Office31`,
            :class:`~tllib.vision.datasets.officehome.OfficeHome`,
            :class:`~tllib.vision.datasets.visda2017.VisDA2017`.
        source (bool): Whether the dataset is used for source domain or not
            (source domains contain no private/unknown classes).
    """
    if dataset_class == Office31:
        public_classes = Office31.CLASSES[:20]
        private_classes = () if source else Office31.CLASSES[20:]
    elif dataset_class == OfficeHome:
        public_classes = sorted(OfficeHome.CLASSES)[:25]
        private_classes = () if source else sorted(OfficeHome.CLASSES)[25:]
    elif dataset_class == VisDA2017:
        public_classes = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck')
        private_classes = () if source else ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard')
    else:
        raise NotImplementedError("Unknown openset domain adaptation dataset: {}".format(dataset_class.__name__))
    return open_set(dataset_class, public_classes, private_classes)
""" if dataset_class == Office31: public_classes = Office31.CLASSES[:20] if source: private_classes = () else: private_classes = Office31.CLASSES[20:] elif dataset_class == OfficeHome: public_classes = sorted(OfficeHome.CLASSES)[:25] if source: private_classes = () else: private_classes = sorted(OfficeHome.CLASSES)[25:] elif dataset_class == VisDA2017: public_classes = ('bicycle', 'bus', 'car', 'motorcycle', 'train', 'truck') if source: private_classes = () else: private_classes = ('aeroplane', 'horse', 'knife', 'person', 'plant', 'skateboard') else: raise NotImplementedError("Unknown openset domain adaptation dataset: {}".format(dataset_class.__name__)) return open_set(dataset_class, public_classes, private_classes) ================================================ FILE: tllib/vision/datasets/oxfordflowers.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .imagelist import ImageList from ._util import download as download_data, check_exits class OxfordFlowers102(ImageList): """ `The Oxford Flowers 102 `_ is a \ consistent of 102 flower categories commonly occurring in the United Kingdom. \ Each class consists of between 40 and 258 images. The images have large scale, \ pose and light variations. In addition, there are categories that have large \ variations within the category and several very similar categories. \ The dataset is divided into a training set, a validation set and a test set. \ The training set and validation set each consist of 10 images per class \ (totalling 1020 images each). \ The test set consists of the remaining 6149 images (minimum 20 per class). Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. 
transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/161c7b222d6745408201/?dl=1"), ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/59b6a3fa3dac4404aa3b/?dl=1"), ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/ec77da479dfb471982fb/?dl=1") ] CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 
'hippeastrum', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily'] def __init__(self, root, split='train', download=False, **kwargs): if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(OxfordFlowers102, self).__init__(root, OxfordFlowers102.CLASSES, os.path.join(root, 'image_list', '{}.txt'.format(split)), **kwargs) ================================================ FILE: tllib/vision/datasets/oxfordpets.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class OxfordIIITPets(ImageList): """`The Oxford-IIIT Pets `_ \ is a 37-category pet dataset with roughly 200 images for each class. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. sample_rate (int): The sampling rates to sample random ``training`` images for each category. Choices include 100, 50, 30, 15. Default: 100. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. 
# ======== FILE: tllib/vision/datasets/pacs.py ========
from typing import Optional
import os
from .imagelist import ImageList
from ._util import download as download_data, check_exits


class PACS(ImageList):
    """PACS Dataset.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset.
            DOC FIX: the original docstring copied Office31's domains; PACS actually
            supports ``'A'``: art_painting, ``'C'``: cartoon, ``'P'``: photo and
            ``'S'``: sketch.
        split (str, optional): one of ``train``, ``val``, ``all`` or ``test``
            (``test`` is an alias for ``all``).
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    .. note:: In `root`, there will exist following files after downloading. ::

        art_painting/
            dog/*.jpg
            ...
        cartoon/
        photo/
        sketch/
        image_list/
            art_painting.txt
            cartoon.txt
            photo.txt
            sketch.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/603a1fea81f2415ab7e0/?dl=1"),
        ("art_painting", "art_painting.tgz", "https://cloud.tsinghua.edu.cn/f/46684292e979402b8d87/?dl=1"),
        ("cartoon", "cartoon.tgz", "https://cloud.tsinghua.edu.cn/f/7bfa413b34ec4f4fa384/?dl=1"),
        ("photo", "photo.tgz", "https://cloud.tsinghua.edu.cn/f/45f71386a668475d8b42/?dl=1"),
        ("sketch", "sketch.tgz", "https://cloud.tsinghua.edu.cn/f/4ba559535e4b4b6981e5/?dl=1"),
    ]
    image_list = {
        "A": "image_list/art_painting_{}.txt",
        "C": "image_list/cartoon_{}.txt",
        "P": "image_list/photo_{}.txt",
        "S": "image_list/sketch_{}.txt"
    }
    CLASSES = ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']

    def __init__(self, root: str, task: str, split='all', download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        assert split in ["train", "val", "all", "test"]
        if split == "test":
            split = "all"
        data_list_file = os.path.join(root, self.image_list[task].format(split))

        if download:
            list(map(lambda args: download_data(root, *args), self.download_list))
        else:
            # BUG FIX: the original `map(lambda file_name, _: ..., self.download_list)`
            # handed each tuple to a two-parameter lambda as a single argument,
            # raising TypeError whenever download=False.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        # Labels in the PACS list files are 1-based; shift them to 0-based.
        super(PACS, self).__init__(root, PACS.CLASSES, data_list_file=data_list_file,
                                   target_transform=lambda x: x - 1, **kwargs)

    @classmethod
    def domains(cls):
        return list(cls.image_list.keys())


# ======== FILE: tllib/vision/datasets/partial/__init__.py ========
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
from ..imagelist import ImageList
from ..office31 import Office31
from ..officehome import OfficeHome
from ..visda2017 import VisDA2017
from ..officecaltech import OfficeCaltech
from .imagenet_caltech import ImageNetCaltech
from .caltech_imagenet import CaltechImageNet
# NOTE(review): this absolute import duplicates the relative import of
# ImageNetCaltech above; kept for compatibility, but one of the two is redundant.
from tllib.vision.datasets.partial.imagenet_caltech import ImageNetCaltech
from typing import Sequence, ClassVar

__all__ = ['Office31', 'OfficeHome', "VisDA2017", "CaltechImageNet", "ImageNetCaltech"]


def partial(dataset_class: ClassVar, partial_classes: Sequence[str]) -> ClassVar:
    """
    Convert a dataset into its partial version.

    In other words, those samples which doesn't belong to `partial_classes` will be
    discarded. Yet `partial` will not change the label space of `dataset_class`.

    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be partial.
        partial_classes (sequence[str]): A sequence of which categories need to be kept
            in the partial dataset. Each element of `partial_classes` must belong to the
            `classes` list of `dataset_class`.

    Examples::

        >>> partial_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> # create a partial dataset class
        >>> PartialOffice31 = partial(Office31, partial_classes)
        >>> # create an instance of the partial dataset
        >>> # (DOC FIX: the original example instantiated the inner name
        >>> # `PartialDataset`, which is not visible to the caller)
        >>> dataset = PartialOffice31(root="data/office31", task="A")
    """
    if not (issubclass(dataset_class, ImageList)):
        raise Exception("Only subclass of ImageList can be partial")

    class PartialDataset(dataset_class):
        def __init__(self, **kwargs):
            super(PartialDataset, self).__init__(**kwargs)
            assert all([c in self.classes for c in partial_classes])
            samples = []
            for (path, label) in self.samples:
                class_name = self.classes[label]
                if class_name in partial_classes:
                    samples.append((path, label))
            self.samples = samples
            self.partial_classes = partial_classes
            self.partial_classes_idx = [self.class_to_idx[c] for c in partial_classes]

    return PartialDataset
def partial(dataset_class: ClassVar, partial_classes: Sequence[str]) -> ClassVar:
    """
    Convert a dataset into its partial version.

    In other words, those samples which doesn't belong to `partial_classes` will be discarded.
    Yet `partial` will not change the label space of `dataset_class`.

    Args:
        dataset_class (class): Dataset class. Only subclass of ``ImageList`` can be partial.
        partial_classes (sequence[str]): A sequence of which categories need to be kept in the partial dataset.\
            Each element of `partial_classes` must belong to the `classes` list of `dataset_class`.

    Examples::

        >>> partial_classes = ['back_pack', 'bike', 'calculator', 'headphones', 'keyboard']
        >>> # create a partial dataset class
        >>> PartialOffice31 = partial(Office31, partial_classes)
        >>> # create an instance of the partial dataset
        >>> dataset = PartialOffice31(root="data/office31", task="A")
    """
    if not (issubclass(dataset_class, ImageList)):
        raise Exception("Only subclass of ImageList can be partial")

    class PartialDataset(dataset_class):
        def __init__(self, **kwargs):
            super(PartialDataset, self).__init__(**kwargs)
            assert all([c in self.classes for c in partial_classes])
            # Keep only samples whose class name appears in `partial_classes`;
            # labels are left untouched so the label space stays that of `dataset_class`.
            self.samples = [(path, label) for (path, label) in self.samples
                            if self.classes[label] in partial_classes]
            self.partial_classes = partial_classes
            self.partial_classes_idx = [self.class_to_idx[c] for c in partial_classes]

    return PartialDataset
def default_partial(dataset_class: ClassVar) -> ClassVar:
    """
    Default partial used in some paper.

    Args:
        dataset_class (class): Dataset class. Currently, dataset_class must be one of
            :class:`~tllib.vision.datasets.office31.Office31`, :class:`~tllib.vision.datasets.officehome.OfficeHome`,
            :class:`~tllib.vision.datasets.visda2017.VisDA2017`,
            :class:`~tllib.vision.datasets.partial.imagenet_caltech.ImageNetCaltech`
            and :class:`~tllib.vision.datasets.partial.caltech_imagenet.CaltechImageNet`.
    """
    # Pick the conventional subset of kept categories for each supported benchmark.
    if dataset_class in [ImageNetCaltech, CaltechImageNet]:
        # These two datasets already enumerate exactly the shared classes.
        kept = dataset_class.CLASSES
    elif dataset_class == Office31:
        # Office-31 -> the 10 classes shared with Caltech (Office-Caltech protocol).
        kept = OfficeCaltech.CLASSES
    elif dataset_class == OfficeHome:
        # First 25 classes in alphabetical order.
        kept = sorted(OfficeHome.CLASSES)[:25]
    elif dataset_class == VisDA2017:
        # First 6 classes in alphabetical order.
        kept = sorted(VisDA2017.CLASSES)[:6]
    else:
        raise NotImplementedError("Unknown partial domain adaptation dataset: {}".format(dataset_class.__name__))
    return partial(dataset_class, kept)
'goose', 'gorilla', 'grand piano 101', 'grapes', 'grasshopper', 'guitar pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill 101', 'head phones', 'helicopter 101', 'hibiscus', 'homer simpson', 'horse', 'horseshoe crab', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house fly', 'human skeleton', 'hummingbird', 'ibis 101', 'ice cream cone', 'iguana', 'ipod', 'iris', 'jesus christ', 'joy stick', 'kangaroo 101', 'kayak', 'ketch 101', 'killer whale', 'knife', 'ladder', 'laptop 101', 'lathe', 'leopards 101', 'license plate', 'lightbulb', 'light house', 'lightning', 'llama 101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah 101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes 101', 'mountain bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich', 'owl', 'palm pilot', 'palm tree', 'paperclip', 'paper shredder', 'pci card', 'penguin', 'people', 'pez dispenser', 'photocopier', 'picnic table', 'playing card', 'porcupine', 'pram', 'praying mantis', 'pyramid', 'raccoon', 'radio telescope', 'rainbow', 'refrigerator', 'revolver 101', 'rifle', 'rotary phone', 'roulette wheel', 'saddle', 'saturn', 'school bus', 'scorpion 101', 'screwdriver', 'segway', 'self propelled lawn mower', 'sextant', 'sheet music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer ball', 'socks', 'soda can', 'spaghetti', 'speed boat', 'spider', 'spoon', 'stained glass', 'starfish 101', 'steering wheel', 'stirrups', 'sunflower 101', 'superman', 'sushi', 'swan', 'swiss army knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy bear', 'teepee', 'telephone box', 'tennis ball', 'tennis court', 'tennis racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top hat', 'touring bike', 'tower pisa', 'traffic light', 'treadmill', 'triceratops', 'tricycle', 'trilobite 101', 'tripod', 't shirt', 'tuning fork', 'tweezer', 'umbrella 101', 'unicorn', 'vcr', 'video projector', 'washing 
class CaltechImageNet(ImageList):
    """Caltech-ImageNet is constructed from `Caltech-256`_ and `ImageNet-1K`_ .

    They share 84 common classes. Caltech-ImageNet keeps all classes of Caltech-256.
    The label is based on the Caltech256 (class 0-255).
    The private classes of ImageNet-1K is discarded.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \
            ``'I'``: ImageNet-1K validation set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and
        ImageList automatically.

        In `root`, there will exist following files after downloading. ::

            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    image_list = {
        "C": "image_list/caltech_256_list.txt",
        "I": "image_list/imagenet_val_84_list.txt",
    }
    CLASSES = _CLASSES

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        # `download_list` is a module-level constant defined below the classes;
        # it is resolved at call time, so forward use here is fine.
        if download:
            list(map(lambda args: download_data(root, *args), download_list))
        else:
            # BUGFIX: `map` with a single iterable passes each (dir, archive, url)
            # pair/triple as ONE argument, so the former two-parameter lambda
            # (`lambda file_name, _: ...`) raised TypeError whenever download=False.
            list(map(lambda args: check_exits(root, args[0]), download_list))

        # ImageNet-1K itself cannot be downloaded automatically any more.
        if not os.path.exists(os.path.join(root, 'val')):
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)

        super(CaltechImageNet, self).__init__(root, CaltechImageNet.CLASSES, data_list_file=data_list_file, **kwargs)
class CaltechImageNetUniversal(ImageList):
    """Caltech-ImageNet-Universal is constructed from `Caltech-256`_ and `ImageNet-1K`_ .

    They share 84 common classes. Caltech-ImageNet keeps all classes of Caltech-256.
    The label is based on the Caltech256 (class 0-255).
    The private classes of ImageNet-1K is grouped into class 256 ("unknown").
    Thus, CaltechImageNetUniversal has 257 classes in total.

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \
            ``'I'``: ImageNet-1K validation set.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory
        since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and
        ImageList automatically.

        In `root`, there will exist following files after downloading. ::

            train/
                n01440764/
                ...
            val/
            256_ObjectCategories/
                001.ak47/
                ...
            image_list/
                caltech_256_list.txt
                ...
    """
    image_list = {
        "C": "image_list/caltech_256_list.txt",
        "I": "image_list/imagenet_val_85_list.txt",
    }
    # Shared Caltech-256 label space plus one extra "unknown" bucket (class 256).
    CLASSES = _CLASSES + ['unknown']

    def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        # `download_list` is the module-level constant defined below; resolved at call time.
        if download:
            list(map(lambda args: download_data(root, *args), download_list))
        else:
            # BUGFIX: `map` with a single iterable passes each (dir, archive, url)
            # entry as ONE argument, so the former two-parameter lambda
            # (`lambda file_name, _: ...`) raised TypeError whenever download=False.
            list(map(lambda args: check_exits(root, args[0]), download_list))

        # ImageNet-1K itself cannot be downloaded automatically any more.
        if not os.path.exists(os.path.join(root, 'val')):
            print("Please put train and val directory of ImageNet-1K manually under {} "
                  "since ImageNet-1K is no longer publicly accessible.".format(root))
            exit(-1)

        super(CaltechImageNetUniversal, self).__init__(root, CaltechImageNetUniversal.CLASSES,
                                                       data_list_file=data_list_file, **kwargs)


# (directory, archive name, url) entries shared by both classes above.
# Defined after the classes on purpose: the names are looked up at __init__ time.
download_list = [
    ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1"),
    ("256_ObjectCategories", "256_ObjectCategories.tar",
     "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar"),
]
montifringilla'), ('goldfinch', 'Carduelis carduelis'), ('house finch', 'linnet', 'Carpodacus mexicanus'), ('junco', 'snowbird'), ('indigo bunting', 'indigo finch', 'indigo bird', 'Passerina cyanea'), ('robin', 'American robin', 'Turdus migratorius'), ('bulbul',), ('jay',), ('magpie',), ('chickadee',), ('water ouzel', 'dipper'), ('kite',), ('bald eagle', 'American eagle', 'Haliaeetus leucocephalus'), ('vulture',), ('great grey owl', 'great gray owl', 'Strix nebulosa'), ('European fire salamander', 'Salamandra salamandra'), ('common newt', 'Triturus vulgaris'), ('eft',), ('spotted salamander', 'Ambystoma maculatum'), ('axolotl', 'mud puppy', 'Ambystoma mexicanum'), ('bullfrog', 'Rana catesbeiana'), ('tree frog', 'tree-frog'), ('tailed frog', 'bell toad', 'ribbed toad', 'tailed toad', 'Ascaphus trui'), ('loggerhead', 'loggerhead turtle', 'Caretta caretta'), ('leatherback turtle', 'leatherback', 'leathery turtle', 'Dermochelys coriacea'), ('mud turtle',), ('terrapin',), ('box turtle', 'box tortoise'), ('banded gecko',), ('common iguana', 'iguana', 'Iguana iguana'), ('American chameleon', 'anole', 'Anolis carolinensis'), ('whiptail', 'whiptail lizard'), ('agama',), ('frilled lizard', 'Chlamydosaurus kingi'), ('alligator lizard',), ('Gila monster', 'Heloderma suspectum'), ('green lizard', 'Lacerta viridis'), ('African chameleon', 'Chamaeleo chamaeleon'), ('Komodo dragon', 'Komodo lizard', 'dragon lizard', 'giant lizard', 'Varanus komodoensis'), ('African crocodile', 'Nile crocodile', 'Crocodylus niloticus'), ('American alligator', 'Alligator mississipiensis'), ('triceratops',), ('thunder snake', 'worm snake', 'Carphophis amoenus'), ('ringneck snake', 'ring-necked snake', 'ring snake'), ('hognose snake', 'puff adder', 'sand viper'), ('green snake', 'grass snake'), ('king snake', 'kingsnake'), ('garter snake', 'grass snake'), ('water snake',), ('vine snake',), ('night snake', 'Hypsiglena torquata'), ('boa constrictor', 'Constrictor constrictor'), ('rock python', 'rock 
snake', 'Python sebae'), ('Indian cobra', 'Naja naja'), ('green mamba',), ('sea snake',), ('horned viper', 'cerastes', 'sand viper', 'horned asp', 'Cerastes cornutus'), ('diamondback', 'diamondback rattlesnake', 'Crotalus adamanteus'), ('sidewinder', 'horned rattlesnake', 'Crotalus cerastes'), ('trilobite',), ('harvestman', 'daddy longlegs', 'Phalangium opilio'), ('scorpion',), ('black and gold garden spider', 'Argiope aurantia'), ('barn spider', 'Araneus cavaticus'), ('garden spider', 'Aranea diademata'), ('black widow', 'Latrodectus mactans'), ('tarantula',), ('wolf spider', 'hunting spider'), ('tick',), ('centipede',), ('black grouse',), ('ptarmigan',), ('ruffed grouse', 'partridge', 'Bonasa umbellus'), ('prairie chicken', 'prairie grouse', 'prairie fowl'), ('peacock',), ('quail',), ('partridge',), ('African grey', 'African gray', 'Psittacus erithacus'), ('macaw',), ('sulphur-crested cockatoo', 'Kakatoe galerita', 'Cacatua galerita'), ('lorikeet',), ('coucal',), ('bee eater',), ('hornbill',), ('hummingbird',), ('jacamar',), ('toucan',), ('drake',), ('red-breasted merganser', 'Mergus serrator'), ('goose',), ('black swan', 'Cygnus atratus'), ('tusker',), ('echidna', 'spiny anteater', 'anteater'), ( 'platypus', 'duckbill', 'duckbilled platypus', 'duck-billed platypus', 'Ornithorhynchus anatinus'), ('wallaby', 'brush kangaroo'), ('koala', 'koalabear', 'kangaroo bear', 'native bear', 'Phascolarctos cinereus'), ('wombat',), ('jellyfish',), ('sea anemone', 'anemone'), ('brain coral',), ('flatworm', 'platyhelminth'), ('nematode', 'nematode worm', 'roundworm'), ('conch',), ('snail',), ('slug',), ('sea slug', 'nudibranch'), ('chiton', 'coat-of-mail shell', 'sea cradle', 'polyplacophore'), ('chambered nautilus', 'pearly nautilus', 'nautilus'), ('Dungeness crab', 'Cancer magister'), ('rock crab', 'Cancer irroratus'), ('fiddler crab',), ( 'king crab', 'Alaska crab', 'Alaskan king crab', 'Alaska king crab', 'Paralithodes camtschatica'), ('American lobster', 'Northernlobster', 
'Maine lobster', 'Homarus americanus'), ('spiny lobster', 'langouste', 'rock lobster', 'crawfish', 'crayfish', 'sea crawfish'), ('crayfish', 'crawfish', 'crawdad', 'crawdaddy'), ('hermit crab',), ('isopod',), ('white stork', 'Ciconia ciconia'), ('black stork', 'Ciconia nigra'), ('spoonbill',), ('flamingo',), ('little blue heron', 'Egretta caerulea'), ('American egret', 'great white heron', 'Egretta albus'), ('bittern',), ('crane',), ('limpkin', 'Aramus pictus'), ('European gallinule', 'Porphyrio porphyrio'), ('American coot', 'marsh hen', 'mud hen', 'water hen', 'Fulica americana'), ('bustard',), ('ruddy turnstone', 'Arenaria interpres'), ('red-backed sandpiper', 'dunlin', 'Erolia alpina'), ('redshank', 'Tringa totanus'), ('dowitcher',), ('oystercatcher', 'oyster catcher'), ('pelican',), ('king penguin', 'Aptenodytes patagonica'), ('albatross', 'mollymawk'), ('grey whale', 'gray whale', 'devilfish', 'Eschrichtius gibbosus', 'Eschrichtius robustus'), ('killer whale', 'killer', 'orca', 'grampus', 'sea wolf', 'Orcinusorca'), ('dugong', 'Dugong dugon'), ('sea lion',), ('Chihuahua',), ('Japanese spaniel',), ('Maltese dog', 'Maltese terrier', 'Maltese'), ('Pekinese', 'Pekingese', 'Peke'), ('Shih-Tzu',), ('Blenheim spaniel',), ('papillon',), ('toy terrier',), ('Rhodesian ridgeback',), ('Afghan hound', 'Afghan'), ('basset', 'basset hound'), ('beagle',), ('bloodhound', 'sleuthhound'), ('bluetick',), ('black-and-tan coonhound',), ('Walker hound', 'Walker foxhound'), ('English foxhound',), ('redbone',), ('borzoi', 'Russian wolfhound'), ('Irish wolfhound',), ('Italian greyhound',), ('whippet',), ('Ibizan hound', 'Ibizan Podenco'), ('Norwegian elkhound', 'elkhound'), ('otterhound', 'otter hound'), ('Saluki', 'gazelle hound'), ('Scottish deerhound', 'deerhound'), ('Weimaraner',), ('Staffordshire bullterrier', 'Staffordshire bull terrier'), ( 'American Staffordshire terrier', 'Staffordshire terrier', 'American pit bull terrier', 'pit bull terrier'), ('Bedlington terrier',), 
('Border terrier',), ('Kerry blue terrier',), ('Irish terrier',), ('Norfolkterrier',), ('Norwich terrier',), ('Yorkshire terrier',), ('wire-haired fox terrier',), ('Lakeland terrier',), ('Sealyham terrier', 'Sealyham'), ('Airedale', 'Airedale terrier'), ('cairn', 'cairn terrier'), ('Australian terrier',), ('Dandie Dinmont', 'Dandie Dinmont terrier'), ('Boston bull', 'Boston terrier'), ('miniature schnauzer',), ('giant schnauzer',), ('standard schnauzer',), ('Scotch terrier', 'Scottish terrier', 'Scottie'), ('Tibetan terrier', 'chrysanthemum dog'), ('silky terrier', 'Sydney silky'), ('soft-coated wheaten terrier',), ('West Highland white terrier',), ('Lhasa', 'Lhasa apso'), ('flat-coated retriever',), ('curly-coated retriever',), ('golden retriever',), ('Labrador retriever',), ('Chesapeake Bay retriever',), ('German short-haired pointer',), ('vizsla', 'Hungarian pointer'), ('English setter',), ('Irish setter', 'red setter'), ('Gordon setter',), ('Brittany spaniel',), ('clumber', 'clumber spaniel'), ('English springer', 'English springer spaniel'), ('Welsh springer spaniel',), ('cocker spaniel', 'English cocker spaniel', 'cocker'), ('Sussex spaniel',), ('Irish water spaniel',), ('kuvasz',), ('schipperke',), ('groenendael',), ('malinois',), ('briard',), ('kelpie',), ('komondor',), ('Old English sheepdog', 'bobtail'), ('Shetland sheepdog', 'Shetland sheep dog', 'Shetland'), ('collie',), ('Border collie',), ('Bouvier des Flandres', 'Bouviers des Flandres'), ('Rottweiler',), ('German shepherd', 'German shepherd dog', 'German police dog', 'alsatian'), ('Doberman', 'Doberman pinscher'), ('miniature pinscher',), ('Greater Swiss Mountain dog',), ('Bernese mountain dog',), ('Appenzeller',), ('EntleBucher',), ('boxer',), ('bull mastiff',), ('Tibetan mastiff',), ('French bulldog',), ('Great Dane',), ('Saint Bernard', 'St Bernard'), ('Eskimo dog', 'husky'), ('malamute', 'malemute', 'Alaskan malamute'), ('Siberian husky',), ('dalmatian', 'coach dog', 'carriage dog'), 
('affenpinscher', 'monkey pinscher', 'monkey dog'), ('basenji',), ('pug', 'pug-dog'), ('Leonberg',), ('Newfoundland', 'Newfoundland dog'), ('Great Pyrenees',), ('Samoyed', 'Samoyede'), ('Pomeranian',), ('chow', 'chow chow'), ('keeshond',), ('Brabancon griffon',), ('Pembroke', 'Pembroke Welsh corgi'), ('Cardigan', 'Cardigan Welsh corgi'), ('toy poodle',), ('miniature poodle',), ('standard poodle',), ('Mexican hairless',), ('timber wolf', 'grey wolf', 'gray wolf', 'Canis lupus'), ('white wolf', 'Arctic wolf', 'Canis lupus tundrarum'), ('red wolf', 'maned wolf', 'Canis rufus', 'Canis niger'), ('coyote', 'prairie wolf', 'brush wolf', 'Canis latrans'), ('dingo', 'warrigal', 'warragal', 'Canis dingo'), ('dhole', 'Cuon alpinus'), ('African hunting dog', 'hyena dog', 'Cape hunting dog', 'Lycaon pictus'), ('hyena', 'hyaena'), ('red fox', 'Vulpes vulpes'), ('kit fox', 'Vulpes macrotis'), ('Arctic fox', 'white fox', 'Alopex lagopus'), ('grey fox', 'gray fox', 'Urocyon cinereoargenteus'), ('tabby', 'tabby cat'), ('tiger cat',), ('Persian cat',), ('Siamese cat', 'Siamese'), ('Egyptian cat',), ('cougar', 'puma', 'catamount', 'mountain lion', 'painter', 'panther', 'Felis concolor'), ('lynx', 'catamount'), ('leopard', 'Panthera pardus'), ('snow leopard', 'ounce', 'Panthera uncia'), ('jaguar', 'panther', 'Panthera onca', 'Felis onca'), ('lion', 'king of beasts', 'Panthera leo'), ('tiger', 'Panthera tigris'), ('cheetah', 'chetah', 'Acinonyx jubatus'), ('brown bear', 'bruin', 'Ursus arctos'), ('American black bear', 'black bear', 'Ursus americanus', 'Euarctos americanus'), ('ice bear', 'polar bear', 'Ursus Maritimus', 'Thalarctos maritimus'), ('sloth bear', 'Melursus ursinus', 'Ursus ursinus'), ('mongoose',), ('meerkat', 'mierkat'), ('tiger beetle',), ('ladybug', 'ladybeetle', 'lady beetle', 'ladybird', 'ladybird beetle'), ('ground beetle', 'carabid beetle'), ('long-horned beetle', 'longicorn', 'longicorn beetle'), ('leaf beetle', 'chrysomelid'), ('dung beetle',), ('rhinoceros 
beetle',), ('weevil',), ('fly',), ('bee',), ('ant', 'emmet', 'pismire'), ('grasshopper', 'hopper'), ('cricket',), ('walking stick', 'walkingstick', 'stick insect'), ('cockroach', 'roach'), ('mantis', 'mantid'), ('cicada', 'cicala'), ('leafhopper',), ('lacewing', 'lacewing fly'), ( 'dragonfly', 'darning needle', "devil's darning needle", 'sewing needle', 'snake feeder', 'snake doctor', 'mosquito hawk', 'skeeter hawk'), ('damselfly',), ('admiral',), ('ringlet', 'ringlet butterfly'), ('monarch', 'monarch butterfly', 'milkweed butterfly', 'Danaus plexippus'), ('cabbage butterfly',), ('sulphur butterfly', 'sulfur butterfly'), ('lycaenid', 'lycaenid butterfly'), ('starfish', 'sea star'), ('sea urchin',), ('sea cucumber', 'holothurian'), ('wood rabbit', 'cottontail', 'cottontail rabbit'), ('hare',), ('Angora', 'Angora rabbit'), ('hamster',), ('porcupine', 'hedgehog'), ('fox squirrel', 'eastern fox squirrel', 'Sciurus niger'), ('marmot',), ('beaver',), ('guinea pig', 'Cavia cobaya'), ('sorrel',), ('zebra',), ('hog', 'pig', 'grunter', 'squealer', 'Sus scrofa'), ('wild boar', 'boar', 'Sus scrofa'), ('warthog',), ('hippopotamus', 'hippo', 'river horse', 'Hippopotamus amphibius'), ('ox',), ('water buffalo', 'water ox', 'Asiatic buffalo', 'Bubalus bubalis'), ('bison',), ('ram', 'tup'), ( 'bighorn', 'bighorn sheep', 'cimarron', 'Rocky Mountain bighorn', 'Rocky Mountain sheep', 'Ovis canadensis'), ('ibex', 'Capra ibex'), ('hartebeest',), ('impala', 'Aepyceros melampus'), ('gazelle',), ('Arabian camel', 'dromedary', 'Camelus dromedarius'), ('llama',), ('weasel',), ('mink',), ('polecat', 'fitch', 'foulmart', 'foumart', 'Mustela putorius'), ('black-footed ferret', 'ferret', 'Mustela nigripes'), ('otter',), ('skunk', 'polecat', 'wood pussy'), ('badger',), ('armadillo',), ('three-toed sloth', 'ai', 'Bradypus tridactylus'), ('orangutan', 'orang', 'orangutang', 'Pongo pygmaeus'), ('gorilla', 'Gorilla gorilla'), ('chimpanzee', 'chimp', 'Pan troglodytes'), ('gibbon', 'Hylobates lar'), 
('siamang', 'Hylobates syndactylus', 'Symphalangus syndactylus'), ('guenon', 'guenon monkey'), ('patas', 'hussar monkey', 'Erythrocebus patas'), ('baboon',), ('macaque',), ('langur',), ('colobus', 'colobus monkey'), ('proboscis monkey', 'Nasalis larvatus'), ('marmoset',), ('capuchin', 'ringtail', 'Cebus capucinus'), ('howler monkey', 'howler'), ('titi', 'titi monkey'), ('spider monkey', 'Ateles geoffroyi'), ('squirrel monkey', 'Saimiri sciureus'), ('Madagascar cat', 'ring-tailed lemur', 'Lemur catta'), ('indri', 'indris', 'Indri indri', 'Indri brevicaudatus'), ('Indian elephant', 'Elephas maximus'), ('African elephant', 'Loxodonta africana'), ('lesser panda', 'red panda', 'panda', 'bear cat', 'cat bear', 'Ailurus fulgens'), ('giant panda', 'panda', 'panda bear', 'coon bear', 'Ailuropoda melanoleuca'), ('barracouta', 'snoek'), ('eel',), ('coho', 'cohoe', 'coho salmon', 'blue jack', 'silver salmon', 'Oncorhynchus kisutch'), ('rock beauty', 'Holocanthus tricolor'), ('anemone fish',), ('sturgeon',), ('gar', 'garfish', 'garpike', 'billfish', 'Lepisosteus osseus'), ('lionfish',), ('puffer', 'pufferfish', 'blowfish', 'globefish'), ('abacus',), ('abaya',), ('academic gown', 'academic robe', "judge's robe"), ('accordion', 'piano accordion', 'squeeze box'), ('acoustic guitar',), ('aircraft carrier', 'carrier', 'flattop', 'attack aircraft carrier'), ('airliner',), ('airship', 'dirigible'), ('altar',), ('ambulance',), ('amphibian', 'amphibious vehicle'), ('analog clock',), ('apiary', 'bee house'), ('apron',), ( 'ashcan', 'trash can', 'garbage can', 'wastebin', 'ash bin', 'ash-bin', 'ashbin', 'dustbin', 'trash barrel', 'trash bin'), ('assault rifle', 'assault gun'), ('backpack', 'back pack', 'knapsack', 'packsack', 'rucksack', 'haversack'), ('bakery', 'bakeshop', 'bakehouse'), ('balance beam', 'beam'), ('balloon',), ('ballpoint', 'ballpoint pen', 'ballpen', 'Biro'), ('Band Aid',), ('banjo',), ('bannister', 'banister', 'balustrade', 'balusters', 'handrail'), ('barbell',), 
('barber chair',), ('barbershop',), ('barn',), ('barometer',), ('barrel', 'cask'), ('barrow', 'garden cart', 'lawn cart', 'wheelbarrow'), ('baseball',), ('basketball',), ('bassinet',), ('bassoon',), ('bathing cap', 'swimming cap'), ('bath towel',), ('bathtub', 'bathing tub', 'bath', 'tub'), ( 'beach wagon', 'station wagon', 'wagon', 'estate car', 'beach waggon', 'station waggon', 'waggon'), ('beacon', 'lighthouse', 'beacon light', 'pharos'), ('beaker',), ('bearskin', 'busby', 'shako'), ('beer bottle',), ('beer glass',), ('bell cote', 'bell cot'), ('bib',), ('bicycle-built-for-two', 'tandem bicycle', 'tandem'), ('bikini', 'two-piece'), ('binder', 'ring-binder'), ('binoculars', 'field glasses', 'opera glasses'), ('birdhouse',), ('boathouse',), ('bobsled', 'bobsleigh', 'bob'), ('bolo tie', 'bolo', 'bola tie', 'bola'), ('bonnet', 'poke bonnet'), ('bookcase',), ('bookshop', 'bookstore', 'bookstall'), ('bottlecap',), ('bow',), ('bow tie', 'bow-tie', 'bowtie'), ('brass', 'memorial tablet', 'plaque'), ('brassiere', 'bra', 'bandeau'), ('breakwater', 'groin', 'groyne', 'mole', 'bulwark', 'seawall', 'jetty'), ('breastplate', 'aegis', 'egis'), ('broom',), ('bucket', 'pail'), ('buckle',), ('bulletproof vest',), ('bullet train', 'bullet'), ('butcher shop', 'meat market'), ('cab', 'hack', 'taxi', 'taxicab'), ('caldron', 'cauldron'), ('candle', 'taper', 'wax light'), ('cannon',), ('canoe',), ('can opener', 'tin opener'), ('cardigan',), ('car mirror',), ('carousel', 'carrousel', 'merry-go-round', 'roundabout', 'whirligig'), ("carpenter's kit", 'tool kit'), ('carton',), ('car wheel',), ( 'cash machine', 'cash dispenser', 'automated teller machine', 'automatic teller machine', 'automated teller', 'automatic teller', 'ATM'), ('cassette',), ('cassette player',), ('castle',), ('catamaran',), ('CD player',), ('cello', 'violoncello'), ('cellular telephone', 'cellular phone', 'cellphone', 'cell', 'mobile phone'), ('chain',), ('chainlink fence',), ( 'chain mail', 'ring mail', 'mail', 'chain 
armor', 'chain armour', 'ring armor', 'ring armour'), ('chain saw', 'chainsaw'), ('chest',), ('chiffonier', 'commode'), ('chime', 'bell', 'gong'), ('china cabinet', 'china closet'), ('Christmas stocking',), ('church', 'church building'), ('cinema', 'movie theater', 'movie theatre', 'movie house', 'picture palace'), ('cleaver', 'meat cleaver', 'chopper'), ('cliff dwelling',), ('cloak',), ('clog', 'geta', 'patten', 'sabot'), ('cocktail shaker',), ('coffee mug',), ('coffeepot',), ('coil', 'spiral', 'volute', 'whorl', 'helix'), ('combination lock',), ('computer keyboard', 'keypad'), ('confectionery', 'confectionary', 'candy store'), ('container ship', 'containership', 'container vessel'), ('convertible',), ('corkscrew', 'bottle screw'), ('cornet', 'horn', 'trumpet', 'trump'), ('cowboy boot',), ('cowboy hat', 'ten-gallon hat'), ('cradle',), ('crane',), ('crash helmet',), ('crate',), ('crib', 'cot'), ('Crock Pot',), ('croquet ball',), ('crutch',), ('cuirass',), ('dam', 'dike', 'dyke'), ('desk',), ('desktop computer',), ('dial telephone', 'dial phone'), ('diaper', 'nappy', 'napkin'), ('digital clock',), ('digital watch',), ('dining table', 'board'), ('dishrag', 'dishcloth'), ('dishwasher', 'dish washer', 'dishwashing machine'), ('disk brake', 'disc brake'), ('dock', 'dockage', 'docking facility'), ('dogsled', 'dog sled', 'dog sleigh'), ('dome',), ('doormat', 'welcome mat'), ('drilling platform', 'offshore rig'), ('drum', 'membranophone', 'tympan'), ('drumstick',), ('dumbbell',), ('Dutch oven',), ('electric fan', 'blower'), ('electric guitar',), ('electric locomotive',), ('entertainment center',), ('envelope',), ('espresso maker',), ('face powder',), ('feather boa', 'boa'), ('file', 'file cabinet', 'filing cabinet'), ('fireboat',), ('fire engine', 'fire truck'), ('fire screen', 'fireguard'), ('flagpole', 'flagstaff'), ('flute', 'transverse flute'), ('folding chair',), ('football helmet',), ('forklift',), ('fountain',), ('fountain pen',), ('four-poster',), ('freight car',), 
('French horn', 'horn'), ('frying pan', 'frypan', 'skillet'), ('fur coat',), ('garbage truck', 'dustcart'), ('gasmask', 'respirator', 'gas helmet'), ('gas pump', 'gasoline pump', 'petrol pump', 'island dispenser'), ('goblet',), ('go-kart',), ('golf ball',), ('golfcart', 'golf cart'), ('gondola',), ('gong', 'tam-tam'), ('gown',), ('grand piano', 'grand'), ('greenhouse', 'nursery', 'glasshouse'), ('grille', 'radiator grille'), ('grocery store', 'grocery', 'food market', 'market'), ('guillotine',), ('hair slide',), ('hair spray',), ('half track',), ('hammer',), ('hamper',), ('hand blower', 'blow dryer', 'blow drier', 'hair dryer', 'hair drier'), ('hand-held computer', 'hand-held microcomputer'), ('handkerchief', 'hankie', 'hanky', 'hankey'), ('hard disc', 'hard disk', 'fixed disk'), ('harmonica', 'mouth organ', 'harp', 'mouth harp'), ('harp',), ('harvester', 'reaper'), ('hatchet',), ('holster',), ('home theater', 'home theatre'), ('honeycomb',), ('hook', 'claw'), ('hoopskirt', 'crinoline'), ('horizontal bar', 'high bar'), ('horse cart', 'horse-cart'), ('hourglass',), ('iPod',), ('iron', 'smoothing iron'), ("jack-o'-lantern",), ('jean', 'blue jean', 'denim'), ('jeep', 'landrover'), ('jersey', 'T-shirt', 'tee shirt'), ('jigsaw puzzle',), ('jinrikisha', 'ricksha', 'rickshaw'), ('joystick',), ('kimono',), ('knee pad',), ('knot',), ('lab coat', 'laboratory coat'), ('ladle',), ('lampshade', 'lamp shade'), ('laptop', 'laptop computer'), ('lawn mower', 'mower'), ('lens cap', 'lens cover'), ('letter opener', 'paper knife', 'paperknife'), ('library',), ('lifeboat',), ('lighter', 'light', 'igniter', 'ignitor'), ('limousine', 'limo'), ('liner', 'ocean liner'), ('lipstick', 'lip rouge'), ('Loafer',), ('lotion',), ('loudspeaker', 'speaker', 'speaker unit', 'loudspeaker system', 'speaker system'), ('loupe', "jeweler's loupe"), ('lumbermill', 'sawmill'), ('magnetic compass',), ('mailbag', 'postbag'), ('mailbox', 'letter box'), ('maillot',), ('maillot', 'tank suit'), ('manhole 
cover',), ('maraca',), ('marimba', 'xylophone'), ('mask',), ('matchstick',), ('maypole',), ('maze', 'labyrinth'), ('measuring cup',), ('medicine chest', 'medicine cabinet'), ('megalith', 'megalithic structure'), ('microphone', 'mike'), ('microwave', 'microwave oven'), ('military uniform',), ('milk can',), ('minibus',), ('miniskirt', 'mini'), ('minivan',), ('missile',), ('mitten',), ('mixing bowl',), ('mobile home', 'manufactured home'), ('Model T',), ('modem',), ('monastery',), ('monitor',), ('moped',), ('mortar',), ('mortarboard',), ('mosque',), ('mosquito net',), ('motor scooter', 'scooter'), ('mountain bike', 'all-terrain bike', 'off-roader'), ('mountain tent',), ('mouse', 'computer mouse'), ('mousetrap',), ('moving van',), ('muzzle',), ('nail',), ('neck brace',), ('necklace',), ('nipple',), ('notebook', 'notebook computer'), ('obelisk',), ('oboe', 'hautboy', 'hautbois'), ('ocarina', 'sweet potato'), ('odometer', 'hodometer', 'mileometer', 'milometer'), ('oil filter',), ('organ', 'pipe organ'), ('oscilloscope', 'scope', 'cathode-ray oscilloscope', 'CRO'), ('overskirt',), ('oxcart',), ('oxygen mask',), ('packet',), ('paddle', 'boat paddle'), ('paddlewheel', 'paddle wheel'), ('padlock',), ('paintbrush',), ('pajama', 'pyjama', "pj's", 'jammies'), ('palace',), ('panpipe', 'pandean pipe', 'syrinx'), ('paper towel',), ('parachute', 'chute'), ('parallel bars', 'bars'), ('park bench',), ('parking meter',), ('passenger car', 'coach', 'carriage'), ('patio', 'terrace'), ('pay-phone', 'pay-station'), ('pedestal', 'plinth', 'footstall'), ('pencil box', 'pencil case'), ('pencil sharpener',), ('perfume', 'essence'), ('Petri dish',), ('photocopier',), ('pick', 'plectrum', 'plectron'), ('pickelhaube',), ('picket fence', 'paling'), ('pickup', 'pickup truck'), ('pier',), ('piggy bank', 'penny bank'), ('pill bottle',), ('pillow',), ('ping-pong ball',), ('pinwheel',), ('pirate', 'pirate ship'), ('pitcher', 'ewer'), ('plane', "carpenter's plane", 'woodworking plane'), 
('planetarium',), ('plastic bag',), ('plate rack',), ('plow', 'plough'), ('plunger', "plumber's helper"), ('Polaroid camera', 'Polaroid Land camera'), ('pole',), ('police van', 'police wagon', 'paddy wagon', 'patrol wagon', 'wagon', 'black Maria'), ('poncho',), ('pool table', 'billiard table', 'snooker table'), ('pop bottle', 'soda bottle'), ('pot', 'flowerpot'), ("potter's wheel",), ('power drill',), ('prayer rug', 'prayer mat'), ('printer',), ('prison', 'prison house'), ('projectile', 'missile'), ('projector',), ('puck', 'hockey puck'), ('punching bag', 'punch bag', 'punching ball', 'punchball'), ('purse',), ('quill', 'quill pen'), ('quilt', 'comforter', 'comfort', 'puff'), ('racer', 'race car', 'racing car'), ('racket', 'racquet'), ('radiator',), ('radio', 'wireless'), ('radio telescope', 'radio reflector'), ('rain barrel',), ('recreational vehicle', 'RV', 'R.V.'), ('reel',), ('reflex camera',), ('refrigerator', 'icebox'), ('remote control', 'remote'), ('restaurant', 'eating house', 'eating place', 'eatery'), ('revolver', 'six-gun', 'six-shooter'), ('rifle',), ('rocking chair', 'rocker'), ('rotisserie',), ('rubber eraser', 'rubber', 'pencil eraser'), ('rugby ball',), ('rule', 'ruler'), ('running shoe',), ('safe',), ('safety pin',), ('saltshaker', 'salt shaker'), ('sandal',), ('sarong',), ('sax', 'saxophone'), ('scabbard',), ('scale', 'weighing machine'), ('school bus',), ('schooner',), ('scoreboard',), ('screen', 'CRT screen'), ('screw',), ('screwdriver',), ('seat belt', 'seatbelt'), ('sewing machine',), ('shield', 'buckler'), ('shoe shop', 'shoe-shop', 'shoe store'), ('shoji',), ('shopping basket',), ('shopping cart',), ('shovel',), ('shower cap',), ('shower curtain',), ('ski',), ('ski mask',), ('sleeping bag',), ('slide rule', 'slipstick'), ('sliding door',), ('slot', 'one-armed bandit'), ('snorkel',), ('snowmobile',), ('snowplow', 'snowplough'), ('soap dispenser',), ('soccer ball',), ('sock',), ('solar dish', 'solar collector', 'solar furnace'), 
('sombrero',), ('soup bowl',), ('space bar',), ('space heater',), ('space shuttle',), ('spatula',), ('speedboat',), ('spider web', "spider's web"), ('spindle',), ('sports car', 'sport car'), ('spotlight', 'spot'), ('stage',), ('steam locomotive',), ('steel arch bridge',), ('steel drum',), ('stethoscope',), ('stole',), ('stone wall',), ('stopwatch', 'stop watch'), ('stove',), ('strainer',), ('streetcar', 'tram', 'tramcar', 'trolley', 'trolley car'), ('stretcher',), ('studio couch', 'day bed'), ('stupa', 'tope'), ('submarine', 'pigboat', 'sub', 'U-boat'), ('suit', 'suit of clothes'), ('sundial',), ('sunglass',), ('sunglasses', 'dark glasses', 'shades'), ('sunscreen', 'sunblock', 'sun blocker'), ('suspension bridge',), ('swab', 'swob', 'mop'), ('sweatshirt',), ('swimming trunks', 'bathing trunks'), ('swing',), ('switch', 'electric switch', 'electrical switch'), ('syringe',), ('table lamp',), ('tank', 'army tank', 'armored combat vehicle', 'armoured combat vehicle'), ('tape player',), ('teapot',), ('teddy', 'teddy bear'), ('television', 'television system'), ('tennis ball',), ('thatch', 'thatched roof'), ('theater curtain', 'theatre curtain'), ('thimble',), ('thresher', 'thrasher', 'threshing machine'), ('throne',), ('tile roof',), ('toaster',), ('tobacco shop', 'tobacconist shop', 'tobacconist'), ('toilet seat',), ('torch',), ('totem pole',), ('tow truck', 'tow car', 'wrecker'), ('toyshop',), ('tractor',), ('trailer truck', 'tractor trailer', 'trucking rig', 'rig', 'articulated lorry', 'semi'), ('tray',), ('trench coat',), ('tricycle', 'trike', 'velocipede'), ('trimaran',), ('tripod',), ('triumphal arch',), ('trolleybus', 'trolley coach', 'trackless trolley'), ('trombone',), ('tub', 'vat'), ('turnstile',), ('typewriter keyboard',), ('umbrella',), ('unicycle', 'monocycle'), ('upright', 'upright piano'), ('vacuum', 'vacuum cleaner'), ('vase',), ('vault',), ('velvet',), ('vending machine',), ('vestment',), ('viaduct',), ('violin', 'fiddle'), ('volleyball',), ('waffle 
iron',), ('wall clock',), ('wallet', 'billfold', 'notecase', 'pocketbook'), ('wardrobe', 'closet', 'press'), ('warplane', 'military plane'), ('washbasin', 'handbasin', 'washbowl', 'lavabo', 'wash-hand basin'), ('washer', 'automatic washer', 'washing machine'), ('water bottle',), ('water jug',), ('water tower',), ('whiskey jug',), ('whistle',), ('wig',), ('window screen',), ('window shade',), ('Windsor tie',), ('wine bottle',), ('wing',), ('wok',), ('wooden spoon',), ('wool', 'woolen', 'woollen'), ('worm fence', 'snake fence', 'snake-rail fence', 'Virginia fence'), ('wreck',), ('yawl',), ('yurt',), ('web site', 'website', 'internet site', 'site'), ('comic book',), ('crossword puzzle', 'crossword'), ('street sign',), ('traffic light', 'traffic signal', 'stoplight'), ('book jacket', 'dust cover', 'dust jacket', 'dust wrapper'), ('menu',), ('plate',), ('guacamole',), ('consomme',), ('hot pot', 'hotpot'), ('trifle',), ('ice cream', 'icecream'), ('ice lolly', 'lolly', 'lollipop', 'popsicle'), ('French loaf',), ('bagel', 'beigel'), ('pretzel',), ('cheeseburger',), ('hotdog', 'hot dog', 'red hot'), ('mashed potato',), ('head cabbage',), ('broccoli',), ('cauliflower',), ('zucchini', 'courgette'), ('spaghetti squash',), ('acorn squash',), ('butternut squash',), ('cucumber', 'cuke'), ('artichoke', 'globe artichoke'), ('bell pepper',), ('cardoon',), ('mushroom',), ('Granny Smith',), ('strawberry',), ('orange',), ('lemon',), ('fig',), ('pineapple', 'ananas'), ('banana',), ('jackfruit', 'jak', 'jack'), ('custard apple',), ('pomegranate',), ('hay',), ('carbonara',), ('chocolate sauce', 'chocolate syrup'), ('dough',), ('meat loaf', 'meatloaf'), ('pizza', 'pizza pie'), ('potpie',), ('burrito',), ('red wine',), ('espresso',), ('cup',), ('eggnog',), ('alp',), ('bubble',), ('cliff', 'drop', 'drop-off'), ('coral reef',), ('geyser',), ('lakeside', 'lakeshore'), ('promontory', 'headland', 'head', 'foreland'), ('sandbar', 'sand bar'), ('seashore', 'coast', 'seacoast', 'sea-coast'), 
('valley', 'vale'), ('volcano',), ('ballplayer', 'baseball player'), ('groom', 'bridegroom'), ('scuba diver',), ('rapeseed',), ('daisy',), ("yellow lady's slipper", 'yellow lady-slipper', 'Cypripedium calceolus', 'Cypripedium parviflorum'), ('corn',), ('acorn',), ('hip', 'rose hip', 'rosehip'), ('buckeye', 'horse chestnut', 'conker'), ('coral fungus',), ('agaric',), ('gyromitra',), ('stinkhorn', 'carrion fungus'), ('earthstar',), ('hen-of-the-woods', 'hen of the woods', 'Polyporus frondosus', 'Grifola frondosa'), ('bolete',), ('ear', 'spike', 'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')]] class ImageNetCaltech(ImageList): """ImageNet-Caltech is constructed from `Caltech-256 `_ and `ImageNet-1K `_ . They share 84 common classes. ImageNet-Caltech keeps all classes of ImageNet-1K. The label is based on the ImageNet-1K (class 0-999) . The private classes of Caltech-256 is discarded. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \ ``'I'``: ImageNet-1K training set. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory since ImageNet-1K is no longer publicly accessible. DALIB will only download Caltech-256 and ImageList automatically. In `root`, there will exist following files after downloading. :: train/ n01440764/ ... val/ 256_ObjectCategories/ 001.ak47/ ... image_list/ caltech_256_list.txt ... 
""" image_list = { "I": "image_list/imagenet_train_1000_list.txt", "C": "image_list/caltech_84_list.txt", } CLASSES = _CLASSES def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs): assert task in self.image_list data_list_file = os.path.join(root, self.image_list[task]) if download: list(map(lambda args: download_data(root, *args), download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), download_list)) if not os.path.exists(os.path.join(root, 'train')): print("Please put train and val directory of ImageNet-1K manually under {} " "since ImageNet-1K is no longer publicly accessible.".format(root)) exit(-1) super(ImageNetCaltech, self).__init__(root, ImageNetCaltech.CLASSES, data_list_file=data_list_file, **kwargs) class ImageNetCaltechUniversal(ImageList): """ImageNet-Caltech-Universal is constructed from `Caltech-256 `_ and `ImageNet-1K `_ . They share 84 common classes. ImageNet-Caltech keeps all classes of ImageNet-1K. The label is based on the ImageNet-1K (class 0-999) . The private classes of Caltech-256 is grouped into class 1000 ("unknown"). Thus, ImageNetCaltechUniversal has 1001 classes in total. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'C'``:Caltech-256, \ ``'I'``: ImageNet-1K training set. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: You need to put ``train`` and ``val`` directory of ImageNet-1K manually in `root` directory since ImageNet-1K is no longer publicly accessible. 
DALIB will only download Caltech-256 and ImageList automatically. In `root`, there will exist following files after downloading. :: train/ n01440764/ ... val/ 256_ObjectCategories/ 001.ak47/ ... image_list/ caltech_256_list.txt ... """ image_list = { "I": "image_list/imagenet_train_1000_list.txt", "C": "image_list/caltech_85_list.txt", } CLASSES = _CLASSES + ["unknown"] def __init__(self, root: str, task: str, download: Optional[bool] = True, **kwargs): assert task in self.image_list data_list_file = os.path.join(root, self.image_list[task]) if download: list(map(lambda args: download_data(root, *args), download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), download_list)) if not os.path.exists(os.path.join(root, 'train')): print("Please put train and val directory of ImageNet-1K manually under {} " "since ImageNet-1K is no longer publicly accessible.".format(root)) exit(-1) super(ImageNetCaltechUniversal, self).__init__(root, ImageNetCaltechUniversal.CLASSES, data_list_file=data_list_file, **kwargs) download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a0d7ea37026946f98965/?dl=1"), ("256_ObjectCategories", "256_ObjectCategories.tar", "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar"), ] ================================================ FILE: tllib/vision/datasets/patchcamelyon.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .imagelist import ImageList from ._util import download as download_data, check_exits class PatchCamelyon(ImageList): """ The `PatchCamelyon `_ dataset contains \ 327680 images of histopathologic scans of lymph node sections. \ The classification task consists in predicting the presence of metastatic tissue \ in given image (i.e., two classes). 
All images are 96x96 pixels Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ CLASSES = ['0', '1'] def __init__(self, root, split, download=False, **kwargs): if download: download_data(root, "patch_camelyon", "patch_camelyon.tgz", "https://cloud.tsinghua.edu.cn/f/21360b3441a54274b843/?dl=1") else: check_exits(root, "patch_camelyon") root = os.path.join(root, "patch_camelyon") super(PatchCamelyon, self).__init__(root, PatchCamelyon.CLASSES, os.path.join(root, "imagelist", "{}.txt".format(split)), **kwargs) ================================================ FILE: tllib/vision/datasets/regression/__init__.py ================================================ from .image_regression import ImageRegression from .dsprites import DSprites from .mpi3d import MPI3D ================================================ FILE: tllib/vision/datasets/regression/dsprites.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional, Sequence import os from .._util import download as download_data, check_exits from .image_regression import ImageRegression class DSprites(ImageRegression): """`DSprites `_ Dataset. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'C'``: Color, \ ``'N'``: Noisy and ``'S'``: Scream. split (str, optional): The dataset split, supports ``train``, or ``test``. factors (sequence[str]): Factors selected. 
Default: ('scale', 'position x', 'position y'). download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. :: color/ ... noisy/ scream/ image_list/ color_train.txt noisy_train.txt scream_train.txt color_test.txt noisy_test.txt scream_test.txt """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/4392ef903ed14017a042/?dl=1"), ("color", "color.tgz", "https://cloud.tsinghua.edu.cn/f/6d243c589d384ff5a212/?dl=1"), ("noisy", "noisy.tgz", "https://cloud.tsinghua.edu.cn/f/9a23ede3be1740328637/?dl=1"), ("scream", "scream.tgz", "https://cloud.tsinghua.edu.cn/f/8fc4d34311bb4db6bcde/?dl=1"), ] image_list = { "C": "color", "N": "noisy", "S": "scream" } FACTORS = ('none', 'shape', 'scale', 'orientation', 'position x', 'position y') def __init__(self, root: str, task: str, split: Optional[str] = 'train', factors: Sequence[str] = ('scale', 'position x', 'position y'), download: Optional[bool] = True, target_transform=None, **kwargs): assert task in self.image_list assert split in ['train', 'test'] for factor in factors: assert factor in self.FACTORS factor_index = [self.FACTORS.index(factor) for factor in factors] if target_transform is None: target_transform = lambda x: x[list(factor_index)] else: target_transform = lambda x: target_transform(x[list(factor_index)]) data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split)) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, 
file_name), self.download_list)) super(DSprites, self).__init__(root, factors, data_list_file=data_list_file, target_transform=target_transform, **kwargs) ================================================ FILE: tllib/vision/datasets/regression/image_regression.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from typing import Optional, Callable, Tuple, Any, List, Sequence import torchvision.datasets as datasets from torchvision.datasets.folder import default_loader import numpy as np class ImageRegression(datasets.VisionDataset): """A generic Dataset class for domain adaptation in image regression Args: root (str): Root directory of dataset factors (sequence[str]): Factors selected. Default: ('scale', 'position x', 'position y'). data_list_file (str): File to read the image list from. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `data_list_file`, each line has `1+len(factors)` values in the following format. :: source_dir/dog_xxx.png x11, x12, ... source_dir/cat_123.png x21, x22, ... target_dir/dog_xxy.png x31, x32, ... target_dir/cat_nsdf3.png x41, x42, ... The first value is the relative path of an image, and the rest values are the ground truth of the corresponding factors. If your data_list_file has different formats, please over-ride :meth:`ImageRegression.parse_data_file`. 
""" def __init__(self, root: str, factors: Sequence[str], data_list_file: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): super().__init__(root, transform=transform, target_transform=target_transform) self.samples = self.parse_data_file(data_list_file) self.factors = factors self.loader = default_loader self.data_list_file = data_list_file def __getitem__(self, index: int) -> Tuple[Any, Tuple[float]]: """ Args: index (int): Index Returns: (image, target) where target is a numpy float array. """ path, target = self.samples[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.target_transform is not None and target is not None: target = self.target_transform(target) return img, target def __len__(self) -> int: return len(self.samples) def parse_data_file(self, file_name: str) -> List[Tuple[str, Any]]: """Parse file to data list Args: file_name (str): The path of data file Returns: List of (image path, (factors)) tuples """ with open(file_name, "r") as f: data_list = [] for line in f.readlines(): data = line.split() path = str(data[0]) target = np.array([float(d) for d in data[1:]], dtype=np.float) if not os.path.isabs(path): path = os.path.join(self.root, path) data_list.append((path, target)) return data_list @property def num_factors(self) -> int: return len(self.factors) ================================================ FILE: tllib/vision/datasets/regression/mpi3d.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Optional, Sequence import os from .._util import download as download_data, check_exits from .image_regression import ImageRegression class MPI3D(ImageRegression): """`MPI3D `_ Dataset. Args: root (str): Root directory of dataset task (str): The task (domain) to create dataset. Choices include ``'C'``: Color, \ ``'N'``: Noisy and ``'S'``: Scream. 
split (str, optional): The dataset split, supports ``train``, or ``test``. factors (sequence[str]): Factors selected. Default: ('horizontal axis', 'vertical axis'). download (bool, optional): If true, downloads the dataset from the internet and puts it \ in root directory. If dataset is already downloaded, it is not downloaded again. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. .. note:: In `root`, there will exist following files after downloading. :: real/ ... realistic/ toy/ image_list/ real_train.txt realistic_train.txt toy_train.txt real_test.txt realistic_test.txt toy_test.txt """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/feacec494d5347b7a6aa/?dl=1"), ("real", "real.tgz", "https://cloud.tsinghua.edu.cn/f/605dd842cd9d4071a0ae/?dl=1"), ("realistic", "realistic.tgz", "https://cloud.tsinghua.edu.cn/f/05743f3071054cc29e25/?dl=1"), ("toy", "toy.tgz", "https://cloud.tsinghua.edu.cn/f/1511dff7853d4abea38f/?dl=1"), ] image_list = { "RL": "real", "RC": "realistic", "T": "toy" } FACTORS = ('horizontal axis', 'vertical axis') def __init__(self, root: str, task: str, split: Optional[str] = 'train', factors: Sequence[str] = ('horizontal axis', 'vertical axis'), download: Optional[bool] = True, target_transform=None, **kwargs): assert task in self.image_list assert split in ['train', 'test'] for factor in factors: assert factor in self.FACTORS factor_index = [self.FACTORS.index(factor) for factor in factors] if target_transform is None: target_transform = lambda x: x[list(factor_index)] / 39. else: target_transform = lambda x: target_transform(x[list(factor_index)]) / 39. 
data_list_file = os.path.join(root, "image_list", "{}_{}.txt".format(self.image_list[task], split)) if download: list(map(lambda args: download_data(root, *args), self.download_list)) else: list(map(lambda file_name, _: check_exits(root, file_name), self.download_list)) super(MPI3D, self).__init__(root, factors, data_list_file=data_list_file, target_transform=target_transform, **kwargs) ================================================ FILE: tllib/vision/datasets/reid/__init__.py ================================================ from .market1501 import Market1501 from .dukemtmc import DukeMTMC from .msmt17 import MSMT17 from .personx import PersonX from .unreal import UnrealPerson __all__ = ['Market1501', 'DukeMTMC', 'MSMT17', 'PersonX', 'UnrealPerson'] ================================================ FILE: tllib/vision/datasets/reid/basedataset.py ================================================ """ Modified from https://github.com/yxgeee/MMT @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import os.path as osp import numpy as np class BaseDataset(object): """ Base class of reid dataset """ def get_imagedata_info(self, data): pids, cams = [], [] for _, pid, camid in data: pids += [pid] cams += [camid] pids = set(pids) cams = set(cams) num_pids = len(pids) num_cams = len(cams) num_imgs = len(data) return num_pids, num_imgs, num_cams def get_videodata_info(self, data, return_tracklet_stats=False): pids, cams, tracklet_stats = [], [], [] for img_paths, pid, camid in data: pids += [pid] cams += [camid] tracklet_stats += [len(img_paths)] pids = set(pids) cams = set(cams) num_pids = len(pids) num_cams = len(cams) num_tracklets = len(data) if return_tracklet_stats: return num_pids, num_tracklets, num_cams, tracklet_stats return num_pids, num_tracklets, num_cams def print_dataset_statistics(self, train, query, galler): raise NotImplementedError def check_before_run(self, required_files): """Checks if required files exist before going deeper. 
Args: required_files (str or list): string file name(s). """ if isinstance(required_files, str): required_files = [required_files] for fpath in required_files: if not osp.exists(fpath): raise RuntimeError('"{}" is not found'.format(fpath)) @property def images_dir(self): return None class BaseImageDataset(BaseDataset): """ Base class of image reid dataset """ def print_dataset_statistics(self, train, query, gallery): num_train_pids, num_train_imgs, num_train_cams = self.get_imagedata_info(train) num_query_pids, num_query_imgs, num_query_cams = self.get_imagedata_info(query) num_gallery_pids, num_gallery_imgs, num_gallery_cams = self.get_imagedata_info(gallery) print("Dataset statistics:") print(" ----------------------------------------") print(" subset | # ids | # images | # cameras") print(" ----------------------------------------") print(" train | {:5d} | {:8d} | {:9d}".format(num_train_pids, num_train_imgs, num_train_cams)) print(" query | {:5d} | {:8d} | {:9d}".format(num_query_pids, num_query_imgs, num_query_cams)) print(" gallery | {:5d} | {:8d} | {:9d}".format(num_gallery_pids, num_gallery_imgs, num_gallery_cams)) print(" ----------------------------------------") class BaseVideoDataset(BaseDataset): """ Base class of video reid dataset """ def print_dataset_statistics(self, train, query, gallery): num_train_pids, num_train_tracklets, num_train_cams, train_tracklet_stats = \ self.get_videodata_info(train, return_tracklet_stats=True) num_query_pids, num_query_tracklets, num_query_cams, query_tracklet_stats = \ self.get_videodata_info(query, return_tracklet_stats=True) num_gallery_pids, num_gallery_tracklets, num_gallery_cams, gallery_tracklet_stats = \ self.get_videodata_info(gallery, return_tracklet_stats=True) tracklet_stats = train_tracklet_stats + query_tracklet_stats + gallery_tracklet_stats min_num = np.min(tracklet_stats) max_num = np.max(tracklet_stats) avg_num = np.mean(tracklet_stats) print("Dataset statistics:") print(" 
-------------------------------------------") print(" subset | # ids | # tracklets | # cameras") print(" -------------------------------------------") print(" train | {:5d} | {:11d} | {:9d}".format(num_train_pids, num_train_tracklets, num_train_cams)) print(" query | {:5d} | {:11d} | {:9d}".format(num_query_pids, num_query_tracklets, num_query_cams)) print(" gallery | {:5d} | {:11d} | {:9d}".format(num_gallery_pids, num_gallery_tracklets, num_gallery_cams)) print(" -------------------------------------------") print(" number of images per tracklet: {} ~ {}, average {:.2f}".format(min_num, max_num, avg_num)) print(" -------------------------------------------") ================================================ FILE: tllib/vision/datasets/reid/convert.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import os.path as osp from torch.utils.data import Dataset from PIL import Image def convert_to_pytorch_dataset(dataset, root=None, transform=None, return_idxes=False): class ReidDataset(Dataset): def __init__(self, dataset, root, transform): super(ReidDataset, self).__init__() self.dataset = dataset self.root = root self.transform = transform self.return_idxes = return_idxes def __len__(self): return len(self.dataset) def __getitem__(self, index): fname, pid, cid = self.dataset[index] fpath = fname if self.root is not None: fpath = osp.join(self.root, fname) img = Image.open(fpath).convert('RGB') if self.transform is not None: img = self.transform(img) if not self.return_idxes: return img, fname, pid, cid else: return img, fname, pid, cid, index return ReidDataset(dataset, root, transform) ================================================ FILE: tllib/vision/datasets/reid/dukemtmc.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from .basedataset import BaseImageDataset from typing import Callable from PIL import Image import os import os.path 
as osp import glob import re from tllib.vision.datasets._util import download class DukeMTMC(BaseImageDataset): """DukeMTMC-reID dataset from `Performance Measures and a Data Set for Multi-Target, Multi-Camera Tracking (ECCV 2016) `_. Dataset statistics: - identities: 1404 (train + query) - images:16522 (train) + 2228 (query) + 17661 (gallery) - cameras: 8 Args: root (str): Root directory of dataset verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True """ dataset_dir = '.' archive_name = 'DukeMTMC-reID.tgz' dataset_url = 'https://cloud.tsinghua.edu.cn/f/cb80f49905ee4e8eb9f0/?dl=1' def __init__(self, root, verbose=True): super(DukeMTMC, self).__init__() download(root, self.dataset_dir, self.archive_name, self.dataset_url) self.relative_dataset_dir = self.dataset_dir self.dataset_dir = osp.join(root, self.dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/query') self.gallery_dir = osp.join(self.dataset_dir, 'DukeMTMC-reID/bounding_box_test') required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir] self.check_before_run(required_files) train = self.process_dir(self.train_dir, relabel=True) query = self.process_dir(self.query_dir, relabel=False) gallery = self.process_dir(self.gallery_dir, relabel=False) if verbose: print("=> DukeMTMC-reID loaded") self.print_dataset_statistics(train, query, gallery) self.train = train self.query = query self.gallery = gallery self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = 
re.compile(r'([-\d]+)_c(\d)') pid_container = set() for img_path in img_paths: pid, _ = map(int, pattern.search(img_path).groups()) pid_container.add(pid) pid2label = {pid: label for label, pid in enumerate(pid_container)} dataset = [] for img_path in img_paths: pid, cid = map(int, pattern.search(img_path).groups()) assert 1 <= cid <= 8 cid -= 1 # index starts from 0 if relabel: pid = pid2label[pid] dataset.append((img_path, pid, cid)) return dataset def translate(self, transform: Callable, target_root: str): """ Translate an image and save it into a specified directory Args: transform (callable): a transform function that maps images from one domain to another domain target_root (str): the root directory to save images """ os.makedirs(target_root, exist_ok=True) translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir) translated_train_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_train') translated_query_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/query') translated_gallery_dir = osp.join(translated_dataset_dir, 'DukeMTMC-reID/bounding_box_test') print("Translating dataset with image to image transform...") self.translate_dir(transform, self.train_dir, translated_train_dir) self.translate_dir(None, self.query_dir, translated_query_dir) self.translate_dir(None, self.gallery_dir, translated_gallery_dir) print("Translation process is done, save dataset to {}".format(translated_dataset_dir)) def translate_dir(self, transform, origin_dir: str, target_dir: str): image_list = os.listdir(origin_dir) for image_name in image_list: if not image_name.endswith(".jpg"): continue image_path = osp.join(origin_dir, image_name) image = Image.open(image_path) translated_image_path = osp.join(target_dir, image_name) translated_image = image if transform: translated_image = transform(image) os.makedirs(os.path.dirname(translated_image_path), exist_ok=True) translated_image.save(translated_image_path) 
================================================ FILE: tllib/vision/datasets/reid/market1501.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from .basedataset import BaseImageDataset from typing import Callable from PIL import Image import os import os.path as osp import glob import re from tllib.vision.datasets._util import download class Market1501(BaseImageDataset): """Market1501 dataset from `Scalable Person Re-identification: A Benchmark (ICCV 2015) `_. Dataset statistics: - identities: 1501 (+1 for background) - images: 12936 (train) + 3368 (query) + 15913 (gallery) - cameras: 6 Args: root (str): Root directory of dataset verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True """ dataset_dir = 'Market-1501-v15.09.15' archive_name = 'Market-1501-v15.09.15.tgz' dataset_url = 'https://cloud.tsinghua.edu.cn/f/29e5f015a7314531b645/?dl=1' def __init__(self, root, verbose=True): super(Market1501, self).__init__() download(root, self.dataset_dir, self.archive_name, self.dataset_url) self.relative_dataset_dir = self.dataset_dir self.dataset_dir = osp.join(root, self.dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'query') self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir] self.check_before_run(required_files) train = self.process_dir(self.train_dir, relabel=True) query = self.process_dir(self.query_dir, relabel=False) gallery = self.process_dir(self.gallery_dir, relabel=False) if verbose: print("=> Market1501 loaded") self.print_dataset_statistics(train, query, gallery) self.train = train self.query = query self.gallery = gallery self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) self.num_query_pids, self.num_query_imgs, 
self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = re.compile(r'([-\d]+)_c(\d)') pid_container = set() for img_path in img_paths: pid, _ = map(int, pattern.search(img_path).groups()) if pid == -1: continue # junk images are just ignored pid_container.add(pid) pid2label = {pid: label for label, pid in enumerate(pid_container)} dataset = [] for img_path in img_paths: pid, cid = map(int, pattern.search(img_path).groups()) if pid == -1: continue # junk images are just ignored assert 0 <= pid <= 1501 # pid == 0 means background assert 1 <= cid <= 6 cid -= 1 # index starts from 0 if relabel: pid = pid2label[pid] dataset.append((img_path, pid, cid)) return dataset def translate(self, transform: Callable, target_root: str): """ Translate an image and save it into a specified directory Args: transform (callable): a transform function that maps images from one domain to another domain target_root (str): the root directory to save images """ os.makedirs(target_root, exist_ok=True) translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir) translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train') translated_query_dir = osp.join(translated_dataset_dir, 'query') translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test') print("Translating dataset with image to image transform...") self.translate_dir(transform, self.train_dir, translated_train_dir) self.translate_dir(None, self.query_dir, translated_query_dir) self.translate_dir(None, self.gallery_dir, translated_gallery_dir) print("Translation process is done, save dataset to {}".format(translated_dataset_dir)) def translate_dir(self, transform, origin_dir: str, target_dir: str): image_list = os.listdir(origin_dir) for image_name in image_list: if not 
image_name.endswith(".jpg"): continue image_path = osp.join(origin_dir, image_name) image = Image.open(image_path) translated_image_path = osp.join(target_dir, image_name) translated_image = image if transform: translated_image = transform(image) os.makedirs(os.path.dirname(translated_image_path), exist_ok=True) translated_image.save(translated_image_path) ================================================ FILE: tllib/vision/datasets/reid/msmt17.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from .basedataset import BaseImageDataset from typing import Callable from PIL import Image import os import os.path as osp from tllib.vision.datasets._util import download class MSMT17(BaseImageDataset): """MSMT17 dataset from `Person Transfer GAN to Bridge Domain Gap for Person Re-Identification (CVPR 2018) `_. Dataset statistics: - identities: 4101 - images: 32621 (train) + 11659 (query) + 82161 (gallery) - cameras: 15 Args: root (str): Root directory of dataset verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True """ dataset_dir = '.' 
archive_name = 'MSMT17_V1.zip' dataset_url = 'https://cloud.tsinghua.edu.cn/f/c254ea490cfa4115940d/?dl=1' def __init__(self, root, verbose=True): super(MSMT17, self).__init__() download(root, self.dataset_dir, self.archive_name, self.dataset_url) self.relative_dataset_dir = self.dataset_dir self.dataset_dir = osp.join(root, self.dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'query') self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir] self.check_before_run(required_files) self.train = self.process_dir(self.train_dir) self.query = self.process_dir(self.query_dir) self.gallery = self.process_dir(self.gallery_dir) if verbose: print("=> MSMT17 loaded") self.print_dataset_statistics(self.train, self.query, self.gallery) self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) def process_dir(self, dir_path): image_list = os.listdir(dir_path) dataset = [] pid_container = set() for image_path in image_list: pid, cid, _ = image_path.split('_') pid = int(pid) cid = int(cid[1:]) - 1 # index starts from 0 full_image_path = osp.join(dir_path, image_path) dataset.append((full_image_path, pid, cid)) pid_container.add(pid) # check if pid starts from 0 and increments with 1 for idx, pid in enumerate(pid_container): assert idx == pid, "See code comment for explanation" return dataset def translate(self, transform: Callable, target_root: str): """ Translate an image and save it into a specified directory Args: transform (callable): a transform function that maps images from one domain to another domain target_root (str): the root directory to save images """ 
os.makedirs(target_root, exist_ok=True) translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir) translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train') translated_query_dir = osp.join(translated_dataset_dir, 'query') translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test') print("Translating dataset with image to image transform...") self.translate_dir(transform, self.train_dir, translated_train_dir) self.translate_dir(None, self.query_dir, translated_query_dir) self.translate_dir(None, self.gallery_dir, translated_gallery_dir) print("Translation process is done, save dataset to {}".format(translated_dataset_dir)) def translate_dir(self, transform, origin_dir: str, target_dir: str): image_list = os.listdir(origin_dir) for image_name in image_list: if not image_name.endswith(".jpg"): continue image_path = osp.join(origin_dir, image_name) image = Image.open(image_path) translated_image_path = osp.join(target_dir, image_name) translated_image = image if transform: translated_image = transform(image) os.makedirs(os.path.dirname(translated_image_path), exist_ok=True) translated_image.save(translated_image_path) ================================================ FILE: tllib/vision/datasets/reid/personx.py ================================================ """ Modified from https://github.com/yxgeee/SpCL @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from .basedataset import BaseImageDataset from typing import Callable from PIL import Image import os import os.path as osp import glob import re from tllib.vision.datasets._util import download class PersonX(BaseImageDataset): """PersonX dataset from `Dissecting Person Re-identification from the Viewpoint of Viewpoint (CVPR 2019) `_. 
Dataset statistics: - identities: 1266 - images: 9840 (train) + 5136 (query) + 30816 (gallery) - cameras: 6 Args: root (str): Root directory of dataset verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True """ dataset_dir = '.' archive_name = 'PersonX.zip' dataset_url = 'https://cloud.tsinghua.edu.cn/f/f506cd11d6b646729bd1/?dl=1' def __init__(self, root, verbose=True): super(PersonX, self).__init__() download(root, self.dataset_dir, self.archive_name, self.dataset_url) self.relative_dataset_dir = self.dataset_dir self.dataset_dir = osp.join(root, self.dataset_dir) self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train') self.query_dir = osp.join(self.dataset_dir, 'query') self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test') required_files = [self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir] self.check_before_run(required_files) train = self.process_dir(self.train_dir, relabel=True) query = self.process_dir(self.query_dir, relabel=False) gallery = self.process_dir(self.gallery_dir, relabel=False) if verbose: print("=> PersonX loaded") self.print_dataset_statistics(train, query, gallery) self.train = train self.query = query self.gallery = gallery self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) self.num_query_pids, self.num_query_imgs, self.num_query_cams = self.get_imagedata_info(self.query) self.num_gallery_pids, self.num_gallery_imgs, self.num_gallery_cams = self.get_imagedata_info(self.gallery) def process_dir(self, dir_path, relabel=False): img_paths = glob.glob(osp.join(dir_path, '*.jpg')) pattern = re.compile(r'([-\d]+)_c([-\d]+)') cam2label = {3: 1, 4: 2, 8: 3, 10: 4, 11: 5, 12: 6} pid_container = set() for img_path in img_paths: pid, _ = map(int, pattern.search(img_path).groups()) pid_container.add(pid) pid2label = {pid: label for label, pid in enumerate(pid_container)} dataset = [] for img_path in img_paths: pid, cid = 
map(int, pattern.search(img_path).groups()) assert (cid in cam2label.keys()) cid = cam2label[cid] cid -= 1 # index starts from 0 if relabel: pid = pid2label[pid] dataset.append((img_path, pid, cid)) return dataset def translate(self, transform: Callable, target_root: str): """ Translate an image and save it into a specified directory Args: transform (callable): a transform function that maps images from one domain to another domain target_root (str): the root directory to save images """ os.makedirs(target_root, exist_ok=True) translated_dataset_dir = osp.join(target_root, self.relative_dataset_dir) translated_train_dir = osp.join(translated_dataset_dir, 'bounding_box_train') translated_query_dir = osp.join(translated_dataset_dir, 'query') translated_gallery_dir = osp.join(translated_dataset_dir, 'bounding_box_test') print("Translating dataset with image to image transform...") self.translate_dir(transform, self.train_dir, translated_train_dir) self.translate_dir(None, self.query_dir, translated_query_dir) self.translate_dir(None, self.gallery_dir, translated_gallery_dir) print("Translation process is done, save dataset to {}".format(translated_dataset_dir)) def translate_dir(self, transform, origin_dir: str, target_dir: str): image_list = os.listdir(origin_dir) for image_name in image_list: if not image_name.endswith(".jpg"): continue image_path = osp.join(origin_dir, image_name) image = Image.open(image_path) translated_image_path = osp.join(target_dir, image_name) translated_image = image if transform: translated_image = transform(image) os.makedirs(os.path.dirname(translated_image_path), exist_ok=True) translated_image.save(translated_image_path) ================================================ FILE: tllib/vision/datasets/reid/unreal.py ================================================ """ Modified from https://github.com/SikaStar/IDM @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from .basedataset import BaseImageDataset from typing import Callable 
import os.path as osp from tllib.vision.datasets._util import download class UnrealPerson(BaseImageDataset): """UnrealPerson dataset from `UnrealPerson: An Adaptive Pipeline towards Costless Person Re-identification (CVPR 2021) `_. Dataset statistics: - identities: 3000 - images: 120,000 - cameras: 34 Args: root (str): Root directory of dataset verbose (bool, optional): If true, print dataset statistics after loading the dataset. Default: True """ dataset_dir = '.' download_list = [ ("list_unreal_train.txt", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a51b22fd760743e7bca6/?dl=1"), ("unreal_v1.1", "unreal_v1.1.tar", "https://cloud.tsinghua.edu.cn/f/a8806bb3bf1744dda5b1/?dl=1"), ("unreal_v1.2", "unreal_v1.2.tar", "https://cloud.tsinghua.edu.cn/f/449224485a654c5baa8f/?dl=1"), ("unreal_v1.3", "unreal_v1.3.tar", "https://cloud.tsinghua.edu.cn/f/069f3162f74849c09c10/?dl=1"), ("unreal_v2.1", "unreal_v2.1.tar", "https://cloud.tsinghua.edu.cn/f/a791aaa42674466eb183/?dl=1"), ("unreal_v2.2", "unreal_v2.2.tar", "https://cloud.tsinghua.edu.cn/f/b601d9f54f964248bd0e/?dl=1"), ("unreal_v2.3", "unreal_v2.3.tar", "https://cloud.tsinghua.edu.cn/f/311ec60e810b42d48d12/?dl=1"), ("unreal_v3.1", "unreal_v3.1.tar", "https://cloud.tsinghua.edu.cn/f/d51b7c1d125e4632bcf9/?dl=1"), ("unreal_v3.2", "unreal_v3.2.tar", "https://cloud.tsinghua.edu.cn/f/4efbd969ea2f4e8197e8/?dl=1"), ("unreal_v3.3", "unreal_v3.3.tar", "https://cloud.tsinghua.edu.cn/f/a3cc3d9c460247848fb7/?dl=1"), ("unreal_v4.1", "unreal_v4.1.tar", "https://cloud.tsinghua.edu.cn/f/ca05183ac9cd4be5a53b/?dl=1"), ("unreal_v4.2", "unreal_v4.2.tar", "https://cloud.tsinghua.edu.cn/f/b90722cbd754496f9f40/?dl=1"), ("unreal_v4.3", "unreal_v4.3.tar", "https://cloud.tsinghua.edu.cn/f/547ae646c3d346038297/?dl=1"), ] def __init__(self, root, verbose=True): super(UnrealPerson, self).__init__() list(map(lambda args: download(root, *args), self.download_list)) self.dataset_dir = osp.join(root, self.dataset_dir) self.train_list = 
osp.join(self.dataset_dir, 'list_unreal_train.txt') required_files = [self.dataset_dir, self.train_list] self.check_before_run(required_files) train = self.process_dir(self.train_list) self.train = train self.query = [] self.gallery = [] self.num_train_pids, self.num_train_imgs, self.num_train_cams = self.get_imagedata_info(self.train) if verbose: print("=> UnrealPerson loaded") print(" ----------------------------------------") print(" subset | # ids | # cams | # images") print(" ----------------------------------------") print(" train | {:5d} | {:5d} | {:8d}" .format(self.num_train_pids, self.num_train_cams, self.num_train_imgs)) print(" ----------------------------------------") def process_dir(self, list_file): with open(list_file, 'r') as f: lines = f.readlines() dataset = [] pid_container = set() for line in lines: line = line.strip() pid = line.split(' ')[1] pid_container.add(pid) pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))} for line in lines: line = line.strip() fname, pid, cid = line.split(' ')[0], line.split(' ')[1], int(line.split(' ')[2]) img_path = osp.join(self.dataset_dir, fname) dataset.append((img_path, pid2label[pid], cid)) return dataset def translate(self, transform: Callable, target_root: str): raise NotImplementedError ================================================ FILE: tllib/vision/datasets/resisc45.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from torchvision.datasets.folder import ImageFolder import random class Resisc45(ImageFolder): """`Resisc45 `_ dataset \ is a scene classification task from remote sensing images. There are 45 classes, \ containing 700 images each, including tennis court, ship, island, lake, \ parking lot, sparse residential, or stadium. \ The image size is RGB 256x256 pixels. .. note:: You need to download the source data manually into `root` directory. 
Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. """ def __init__(self, root, split='train', download=False, **kwargs): super(Resisc45, self).__init__(root, **kwargs) random.seed(0) random.shuffle(self.samples) if split == 'train': self.samples = self.samples[:25200] else: self.samples = self.samples[25200:] @property def num_classes(self) -> int: """Number of classes""" return len(self.classes) ================================================ FILE: tllib/vision/datasets/retinopathy.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .imagelist import ImageList class Retinopathy(ImageList): """`Retinopathy `_ dataset \ consists of image-label pairs with high-resolution retina images, and labels that indicate \ the presence of Diabetic Retinopahy (DR) in a 0-4 scale (No DR, Mild, Moderate, Severe, \ or Proliferative DR). .. note:: You need to download the source data manually into `root` directory. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. transform (callable, optional): A function/transform that takes in an PIL image and returns a \ transformed version. E.g, :class:`torchvision.transforms.RandomCrop`. target_transform (callable, optional): A function/transform that takes in the target and transforms it. 
""" CLASSES = ['No DR', 'Mild', 'Moderate', 'Severe', 'Proliferative DR'] def __init__(self, root, split, download=False, **kwargs): super(Retinopathy, self).__init__(os.path.join(root, split), Retinopathy.CLASSES, os.path.join(root, "image_list", "{}.txt".format(split)), **kwargs) ================================================ FILE: tllib/vision/datasets/segmentation/__init__.py ================================================ from .segmentation_list import SegmentationList from .cityscapes import Cityscapes, FoggyCityscapes from .gta5 import GTA5 from .synthia import Synthia __all__ = ["SegmentationList", "Cityscapes", "GTA5", "Synthia", "FoggyCityscapes"] ================================================ FILE: tllib/vision/datasets/segmentation/cityscapes.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .segmentation_list import SegmentationList from .._util import download as download_data class Cityscapes(SegmentationList): """`Cityscapes `_ is a real-world semantic segmentation dataset collected in driving scenarios. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``val``. data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit'. label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'. mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. .. note:: You need to download Cityscapes manually. Ensure that there exist following files in the `root` directory before you using this class. 
:: leftImg8bit/ train/ val/ test/ gtFine/ train/ val/ test/ """ CLASSES = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 'bicycle'] ID_TO_TRAIN_ID = { 7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18 } TRAIN_ID_TO_COLOR = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32), [0, 0, 0]] download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/a23536bb8a724e91af39/?dl=1"), ] EVALUATE_CLASSES = CLASSES def __init__(self, root, split='train', data_folder='leftImg8bit', label_folder='gtFine', **kwargs): assert split in ['train', 'val'] # download meta information from Internet list(map(lambda args: download_data(root, *args), self.download_list)) data_list_file = os.path.join(root, "image_list", "{}.txt".format(split)) self.split = split super(Cityscapes, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, os.path.join(data_folder, split), os.path.join(label_folder, split), id_to_train_id=Cityscapes.ID_TO_TRAIN_ID, train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs) def parse_label_file(self, label_list_file): with open(label_list_file, "r") as f: label_list = [line.strip().replace("leftImg8bit", "gtFine_labelIds") for line in f.readlines()] return label_list class FoggyCityscapes(Cityscapes): """`Foggy Cityscapes `_ is a real-world semantic segmentation dataset collected in foggy driving scenarios. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``val``. 
data_folder (str, optional): Sub-directory of the image. Default: 'leftImg8bit'. label_folder (str, optional): Sub-directory of the label. Default: 'gtFine'. beta (float, optional): The parameter for foggy. Choices includes: 0.005, 0.01, 0.02. Default: 0.02 mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. .. note:: You need to download Cityscapes manually. Ensure that there exist following files in the `root` directory before you using this class. :: leftImg8bit_foggy/ train/ val/ test/ gtFine/ train/ val/ test/ """ def __init__(self, root, split='train', data_folder='leftImg8bit_foggy', label_folder='gtFine', beta=0.02, **kwargs): assert beta in [0.02, 0.01, 0.005] self.beta = beta super(FoggyCityscapes, self).__init__(root, split, data_folder, label_folder, **kwargs) def parse_data_file(self, file_name): """Parse file to image list Args: file_name (str): The path of data file Returns: List of image path """ with open(file_name, "r") as f: data_list = [line.strip().replace("leftImg8bit", "leftImg8bit_foggy_beta_{}".format(self.beta)) for line in f.readlines()] return data_list ================================================ FILE: tllib/vision/datasets/segmentation/gta5.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .segmentation_list import SegmentationList from .cityscapes import Cityscapes from .._util import download as download_data class GTA5(SegmentationList): """`GTA5 `_ Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``. data_folder (str, optional): Sub-directory of the image. Default: 'images'. label_folder (str, optional): Sub-directory of the label. Default: 'labels'. 
mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. .. note:: You need to download GTA5 manually. Ensure that there exist following directories in the `root` directory before you using this class. :: images/ labels/ """ download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/f719733e339544e9a330/?dl=1"), ] def __init__(self, root, split='train', data_folder='images', label_folder='labels', **kwargs): assert split in ['train'] # download meta information from Internet list(map(lambda args: download_data(root, *args), self.download_list)) data_list_file = os.path.join(root, "image_list", "{}.txt".format(split)) self.split = split super(GTA5, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder, label_folder, id_to_train_id=Cityscapes.ID_TO_TRAIN_ID, train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs) ================================================ FILE: tllib/vision/datasets/segmentation/segmentation_list.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from typing import Sequence, Optional, Dict, Callable from PIL import Image import tqdm import numpy as np from torch.utils import data import torch class SegmentationList(data.Dataset): """A generic Dataset class for domain adaptation in image segmentation Args: root (str): Root directory of dataset classes (seq[str]): The names of all the classes data_list_file (str): File to read the image list from. label_list_file (str): File to read the label list from. data_folder (str): Sub-directory of the image. label_folder (str): Sub-directory of the label. mean (seq[float]): mean BGR value. Normalize and convert to the image if not None. Default: None. 
id_to_train_id (dict, optional): the map between the id on the label and the actual train id. train_id_to_color (seq, optional): the map between the train id and the color. transforms (callable, optional): A function/transform that takes in (PIL Image, label) pair \ and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. .. note:: In ``data_list_file``, each line is the relative path of an image. If your data_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_data_file`. :: source_dir/dog_xxx.png target_dir/dog_xxy.png In ``label_list_file``, each line is the relative path of an label. If your label_list_file has different formats, please over-ride :meth:`~SegmentationList.parse_label_file`. .. warning:: When mean is not None, please do not provide Normalize and ToTensor in transforms. """ def __init__(self, root: str, classes: Sequence[str], data_list_file: str, label_list_file: str, data_folder: str, label_folder: str, id_to_train_id: Optional[Dict] = None, train_id_to_color: Optional[Sequence] = None, transforms: Optional[Callable] = None): self.root = root self.classes = classes self.data_list_file = data_list_file self.label_list_file = label_list_file self.data_folder = data_folder self.label_folder = label_folder self.ignore_label = 255 self.id_to_train_id = id_to_train_id self.train_id_to_color = np.array(train_id_to_color) self.data_list = self.parse_data_file(self.data_list_file) self.label_list = self.parse_label_file(self.label_list_file) self.transforms = transforms def parse_data_file(self, file_name): """Parse file to image list Args: file_name (str): The path of data file Returns: List of image path """ with open(file_name, "r") as f: data_list = [line.strip() for line in f.readlines()] return data_list def parse_label_file(self, file_name): """Parse file to label list Args: file_name (str): The path of data file Returns: List of label path """ with open(file_name, "r") as f: 
label_list = [line.strip() for line in f.readlines()] return label_list def __len__(self): return len(self.data_list) def __getitem__(self, index): image_name = self.data_list[index] label_name = self.label_list[index] image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB') label = Image.open(os.path.join(self.root, self.label_folder, label_name)) image, label = self.transforms(image, label) # remap label if isinstance(label, torch.Tensor): label = label.numpy() label = np.asarray(label, np.int64) label_copy = self.ignore_label * np.ones(label.shape, dtype=np.int64) if self.id_to_train_id: for k, v in self.id_to_train_id.items(): label_copy[label == k] = v return image, label_copy.copy() @property def num_classes(self) -> int: """Number of classes""" return len(self.classes) def decode_target(self, target): """ Decode label (each value is integer) into the corresponding RGB value. Args: target (numpy.array): label in shape H x W Returns: RGB label (PIL Image) in shape H x W x 3 """ target = target.copy() target[target == 255] = self.num_classes # unknown label is black on the RGB label target = self.train_id_to_color[target] return Image.fromarray(target.astype(np.uint8)) def collect_image_paths(self): """Return a list of the absolute path of all the images""" return [os.path.join(self.root, self.data_folder, image_name) for image_name in self.data_list] @staticmethod def _save_pil_image(image, path): os.makedirs(os.path.dirname(path), exist_ok=True) image.save(path) def translate(self, transform: Callable, target_root: str, color=False): """ Translate an image and save it into a specified directory Args: transform (callable): a transform function that maps (image, label) pair from one domain to another domain target_root (str): the root directory to save images and labels """ os.makedirs(target_root, exist_ok=True) for image_name, label_name in zip(tqdm.tqdm(self.data_list), self.label_list): image_path = os.path.join(target_root, 
self.data_folder, image_name) label_path = os.path.join(target_root, self.label_folder, label_name) if os.path.exists(image_path) and os.path.exists(label_path): continue image = Image.open(os.path.join(self.root, self.data_folder, image_name)).convert('RGB') label = Image.open(os.path.join(self.root, self.label_folder, label_name)) translated_image, translated_label = transform(image, label) self._save_pil_image(translated_image, image_path) self._save_pil_image(translated_label, label_path) if color: colored_label = self.decode_target(np.array(translated_label)) file_name, file_ext = os.path.splitext(label_name) self._save_pil_image(colored_label, os.path.join(target_root, self.label_folder, "{}_color{}".format(file_name, file_ext))) @property def evaluate_classes(self): """The name of classes to be evaluated""" return self.classes @property def ignore_classes(self): """The name of classes to be ignored""" return list(set(self.classes) - set(self.evaluate_classes)) ================================================ FILE: tllib/vision/datasets/segmentation/synthia.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import os from .segmentation_list import SegmentationList from .cityscapes import Cityscapes from .._util import download as download_data class Synthia(SegmentationList): """`SYNTHIA `_ Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``. data_folder (str, optional): Sub-directory of the image. Default: 'RGB'. label_folder (str, optional): Sub-directory of the label. Default: 'synthia_mapped_to_cityscapes'. mean (seq[float]): mean BGR value. Normalize the image if not None. Default: None. transforms (callable, optional): A function/transform that takes in (PIL image, label) pair \ and returns a transformed version. E.g, :class:`~tllib.vision.transforms.segmentation.Resize`. .. note:: You need to download GTA5 manually. 
Ensure that there exist following directories in the `root` directory before you using this class. :: RGB/ synthia_mapped_to_cityscapes/ """ ID_TO_TRAIN_ID = { 3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5, 15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12, 8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18 } download_list = [ ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/1c652d518e0347e2800d/?dl=1"), ] def __init__(self, root, split='train', data_folder='RGB', label_folder='synthia_mapped_to_cityscapes', **kwargs): assert split in ['train'] # download meta information from Internet list(map(lambda args: download_data(root, *args), self.download_list)) data_list_file = os.path.join(root, "image_list", "{}.txt".format(split)) super(Synthia, self).__init__(root, Cityscapes.CLASSES, data_list_file, data_list_file, data_folder, label_folder, id_to_train_id=Synthia.ID_TO_TRAIN_ID, train_id_to_color=Cityscapes.TRAIN_ID_TO_COLOR, **kwargs) @property def evaluate_classes(self): return [ 'road', 'sidewalk', 'building', 'traffic light', 'traffic sign', 'vegetation', 'sky', 'person', 'rider', 'car', 'bus', 'motorcycle', 'bicycle' ] ================================================ FILE: tllib/vision/datasets/stanford_cars.py ================================================ """ @author: Yifei Ji @contact: jiyf990330@163.com """ import os from typing import Optional from .imagelist import ImageList from ._util import download as download_data, check_exits class StanfordCars(ImageList): """`The Stanford Cars `_ \ contains 16,185 images of 196 classes of cars. \ Each category has been split roughly in a 50-50 split. \ There are 8,144 images for training and 8,041 images for testing. Args: root (str): Root directory of dataset split (str, optional): The dataset split, supports ``train``, or ``test``. sample_rate (int): The sampling rates to sample random ``training`` images for each category. Choices include 100, 50, 30, 15. Default: 100. 
class StanfordCars(ImageList):
    """The Stanford Cars dataset \
    contains 16,185 images of 196 classes of cars. \
    Each category has been split roughly in a 50-50 split. \
    There are 8,144 images for training and 8,041 images for testing.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/aeeb690e9886442aa267/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/fd80c30c120a42a08fd3/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/01e6b279f20440cb8bf9/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    # Category names are simply the stringified labels '1' .. '196'.
    CLASSES = [str(i) for i in range(1, 197)]

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100,
                 download: Optional[bool] = False, **kwargs):
        if split == 'train':
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # BUG FIX: the original used ``map`` with a two-argument lambda over a
            # single iterable of (directory, archive, url) tuples, which raises
            # TypeError at runtime whenever download=False. Unpack explicitly.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        super(StanfordCars, self).__init__(root, StanfordCars.CLASSES,
                                           data_list_file=data_list_file, **kwargs)
class StanfordDogs(ImageList):
    """The Stanford Dogs dataset \
    contains 20,580 images of 120 breeds of dogs from around the world. \
    Each category is composed of exactly 100 training examples and around 72 testing examples.

    Args:
        root (str): Root directory of dataset
        split (str, optional): The dataset split, supports ``train``, or ``test``.
        sample_rate (int): The sampling rates to sample random ``training`` images for each category.
            Choices include 100, 50, 30, 15. Default: 100.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, :class:`torchvision.transforms.RandomCrop`.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
            test/
            image_list/
                train_100.txt
                train_50.txt
                train_30.txt
                train_15.txt
                test.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/7685b13c549a4591b429/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/9f19a6d1b14b4f1e8d13/?dl=1"),
        ("test", "test.tgz", "https://cloud.tsinghua.edu.cn/f/a497b21e31cc4bfc9d45/?dl=1"),
    ]
    image_list = {
        "train": "image_list/train_100.txt",
        "train100": "image_list/train_100.txt",
        "train50": "image_list/train_50.txt",
        "train30": "image_list/train_30.txt",
        "train15": "image_list/train_15.txt",
        "test": "image_list/test.txt",
        "test100": "image_list/test.txt",
    }
    # WordNet-id prefixed breed names, in label order.
    CLASSES = ['n02085620-Chihuahua', 'n02085782-Japanese_spaniel', 'n02085936-Maltese_dog',
               'n02086079-Pekinese', 'n02086240-Shih-Tzu', 'n02086646-Blenheim_spaniel',
               'n02086910-papillon', 'n02087046-toy_terrier', 'n02087394-Rhodesian_ridgeback',
               'n02088094-Afghan_hound', 'n02088238-basset', 'n02088364-beagle',
               'n02088466-bloodhound', 'n02088632-bluetick', 'n02089078-black-and-tan_coonhound',
               'n02089867-Walker_hound', 'n02089973-English_foxhound', 'n02090379-redbone',
               'n02090622-borzoi', 'n02090721-Irish_wolfhound', 'n02091032-Italian_greyhound',
               'n02091134-whippet', 'n02091244-Ibizan_hound', 'n02091467-Norwegian_elkhound',
               'n02091635-otterhound', 'n02091831-Saluki', 'n02092002-Scottish_deerhound',
               'n02092339-Weimaraner', 'n02093256-Staffordshire_bullterrier',
               'n02093428-American_Staffordshire_terrier', 'n02093647-Bedlington_terrier',
               'n02093754-Border_terrier', 'n02093859-Kerry_blue_terrier', 'n02093991-Irish_terrier',
               'n02094114-Norfolk_terrier', 'n02094258-Norwich_terrier', 'n02094433-Yorkshire_terrier',
               'n02095314-wire-haired_fox_terrier', 'n02095570-Lakeland_terrier',
               'n02095889-Sealyham_terrier', 'n02096051-Airedale', 'n02096177-cairn',
               'n02096294-Australian_terrier', 'n02096437-Dandie_Dinmont', 'n02096585-Boston_bull',
               'n02097047-miniature_schnauzer', 'n02097130-giant_schnauzer',
               'n02097209-standard_schnauzer', 'n02097298-Scotch_terrier', 'n02097474-Tibetan_terrier',
               'n02097658-silky_terrier', 'n02098105-soft-coated_wheaten_terrier',
               'n02098286-West_Highland_white_terrier', 'n02098413-Lhasa',
               'n02099267-flat-coated_retriever', 'n02099429-curly-coated_retriever',
               'n02099601-golden_retriever', 'n02099712-Labrador_retriever',
               'n02099849-Chesapeake_Bay_retriever', 'n02100236-German_short-haired_pointer',
               'n02100583-vizsla', 'n02100735-English_setter', 'n02100877-Irish_setter',
               'n02101006-Gordon_setter', 'n02101388-Brittany_spaniel', 'n02101556-clumber',
               'n02102040-English_springer', 'n02102177-Welsh_springer_spaniel',
               'n02102318-cocker_spaniel', 'n02102480-Sussex_spaniel', 'n02102973-Irish_water_spaniel',
               'n02104029-kuvasz', 'n02104365-schipperke', 'n02105056-groenendael',
               'n02105162-malinois', 'n02105251-briard', 'n02105412-kelpie', 'n02105505-komondor',
               'n02105641-Old_English_sheepdog', 'n02105855-Shetland_sheepdog', 'n02106030-collie',
               'n02106166-Border_collie', 'n02106382-Bouvier_des_Flandres', 'n02106550-Rottweiler',
               'n02106662-German_shepherd', 'n02107142-Doberman', 'n02107312-miniature_pinscher',
               'n02107574-Greater_Swiss_Mountain_dog', 'n02107683-Bernese_mountain_dog',
               'n02107908-Appenzeller', 'n02108000-EntleBucher', 'n02108089-boxer',
               'n02108422-bull_mastiff', 'n02108551-Tibetan_mastiff', 'n02108915-French_bulldog',
               'n02109047-Great_Dane', 'n02109525-Saint_Bernard', 'n02109961-Eskimo_dog',
               'n02110063-malamute', 'n02110185-Siberian_husky', 'n02110627-affenpinscher',
               'n02110806-basenji', 'n02110958-pug', 'n02111129-Leonberg', 'n02111277-Newfoundland',
               'n02111500-Great_Pyrenees', 'n02111889-Samoyed', 'n02112018-Pomeranian',
               'n02112137-chow', 'n02112350-keeshond', 'n02112706-Brabancon_griffon',
               'n02113023-Pembroke', 'n02113186-Cardigan', 'n02113624-toy_poodle',
               'n02113712-miniature_poodle', 'n02113799-standard_poodle',
               'n02113978-Mexican_hairless', 'n02115641-dingo', 'n02115913-dhole',
               'n02116738-African_hunting_dog']

    def __init__(self, root: str, split: str, sample_rate: Optional[int] = 100,
                 download: Optional[bool] = False, **kwargs):
        if split == 'train':
            list_name = 'train' + str(sample_rate)
            assert list_name in self.image_list
            data_list_file = os.path.join(root, self.image_list[list_name])
        else:
            data_list_file = os.path.join(root, self.image_list['test'])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # BUG FIX: the original used ``map`` with a two-argument lambda over a
            # single iterable of (directory, archive, url) tuples, which raises
            # TypeError at runtime whenever download=False. Unpack explicitly.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        super(StanfordDogs, self).__init__(root, StanfordDogs.CLASSES,
                                           data_list_file=data_list_file, **kwargs)
class VisDA2017(ImageList):
    """VisDA-2017 Dataset

    Args:
        root (str): Root directory of dataset
        task (str): The task (domain) to create dataset. Choices include ``'Synthetic'``: synthetic images and \
            ``'Real'``: real-world images.
        download (bool, optional): If true, downloads the dataset from the internet and puts it \
            in root directory. If dataset is already downloaded, it is not downloaded again.
        transform (callable, optional): A function/transform that takes in an PIL image and returns a \
            transformed version. E.g, ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the target and transforms it.

    .. note:: In `root`, there will exist following files after downloading.
        ::
            train/
                aeroplance/
                    *.png
                ...
            validation/
            image_list/
                train.txt
                validation.txt
    """
    download_list = [
        ("image_list", "image_list.zip", "https://cloud.tsinghua.edu.cn/f/c107de37b8094c5398dc/?dl=1"),
        ("train", "train.tgz", "https://cloud.tsinghua.edu.cn/f/c5f3ce59139144ec8221/?dl=1"),
        ("validation", "validation.tgz", "https://cloud.tsinghua.edu.cn/f/da70e4b1cf514ecea562/?dl=1")
    ]
    image_list = {
        "Synthetic": "image_list/train.txt",
        "Real": "image_list/validation.txt"
    }
    CLASSES = ['aeroplane', 'bicycle', 'bus', 'car', 'horse', 'knife',
               'motorcycle', 'person', 'plant', 'skateboard', 'train', 'truck']

    def __init__(self, root: str, task: str, download: Optional[bool] = False, **kwargs):
        assert task in self.image_list
        data_list_file = os.path.join(root, self.image_list[task])

        if download:
            for args in self.download_list:
                download_data(root, *args)
        else:
            # BUG FIX: the original used ``map`` with a two-argument lambda over a
            # single iterable of (directory, archive, url) tuples, which raises
            # TypeError at runtime whenever download=False. Unpack explicitly.
            for file_name, *_ in self.download_list:
                check_exits(root, file_name)

        super(VisDA2017, self).__init__(root, VisDA2017.CLASSES,
                                        data_list_file=data_list_file, **kwargs)

    @classmethod
    def domains(cls):
        """Names of the available domains (keys of :attr:`image_list`)."""
        return list(cls.image_list.keys())
500 def copy_head(self): return nn.Linear(500, self.num_classes) class DTN(nn.Sequential): def __init__(self, num_classes=10): super(DTN, self).__init__( nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(64), nn.Dropout2d(0.1), nn.ReLU(), nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(128), nn.Dropout2d(0.3), nn.ReLU(), nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2), nn.BatchNorm2d(256), nn.Dropout2d(0.5), nn.ReLU(), nn.Flatten(start_dim=1), nn.Linear(256 * 4 * 4, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(), ) self.num_classes = num_classes self.out_features = 512 def copy_head(self): return nn.Linear(512, self.num_classes) def lenet(pretrained=False, **kwargs): """LeNet model from `"Gradient-based learning applied to document recognition" `_ Args: num_classes (int): number of classes. Default: 10 .. note:: The input image size must be 28 x 28. """ return LeNet(**kwargs) def dtn(pretrained=False, **kwargs): """ DTN model Args: num_classes (int): number of classes. Default: 10 .. note:: The input image size must be 32 x 32. """ return DTN(**kwargs) ================================================ FILE: tllib/vision/models/keypoint_detection/__init__.py ================================================ from .pose_resnet import * from . import loss __all__ = ['pose_resnet'] ================================================ FILE: tllib/vision/models/keypoint_detection/loss.py ================================================ """ Modified from https://github.com/microsoft/human-pose-estimation.pytorch @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch.nn as nn import torch.nn.functional as F class JointsMSELoss(nn.Module): """ Typical MSE loss for keypoint detection. Args: reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'``. 
``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output. Default: ``'mean'`` Inputs: - output (tensor): heatmap predictions - target (tensor): heatmap labels - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None. Shape: - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints, H and W is the height and width of the heatmap respectively. - target: :math:`(minibatch, K, H, W)`. - target_weight: :math:`(minibatch, K)`. - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`. """ def __init__(self, reduction='mean'): super(JointsMSELoss, self).__init__() self.criterion = nn.MSELoss(reduction='none') self.reduction = reduction def forward(self, output, target, target_weight=None): B, K, _, _ = output.shape heatmaps_pred = output.reshape((B, K, -1)) heatmaps_gt = target.reshape((B, K, -1)) loss = self.criterion(heatmaps_pred, heatmaps_gt) * 0.5 if target_weight is not None: loss = loss * target_weight.view((B, K, 1)) if self.reduction == 'mean': return loss.mean() elif self.reduction == 'none': return loss.mean(dim=-1) class JointsKLLoss(nn.Module): """ KL Divergence for keypoint detection proposed by `Regressive Domain Adaptation for Unsupervised Keypoint Detection `_. Args: reduction (str, optional): Specifies the reduction to apply to the output: ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, ``'mean'``: the sum of the output will be divided by the number of elements in the output. Default: ``'mean'`` Inputs: - output (tensor): heatmap predictions - target (tensor): heatmap labels - target_weight (tensor): whether the keypoint is visible. All keypoint is visible if None. Default: None. Shape: - output: :math:`(minibatch, K, H, W)` where K means the number of keypoints, H and W is the height and width of the heatmap respectively. 
- target: :math:`(minibatch, K, H, W)`. - target_weight: :math:`(minibatch, K)`. - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(minibatch, K)`. """ def __init__(self, reduction='mean', epsilon=0.): super(JointsKLLoss, self).__init__() self.criterion = nn.KLDivLoss(reduction='none') self.reduction = reduction self.epsilon = epsilon def forward(self, output, target, target_weight=None): B, K, _, _ = output.shape heatmaps_pred = output.reshape((B, K, -1)) heatmaps_pred = F.log_softmax(heatmaps_pred, dim=-1) heatmaps_gt = target.reshape((B, K, -1)) heatmaps_gt = heatmaps_gt + self.epsilon heatmaps_gt = heatmaps_gt / heatmaps_gt.sum(dim=-1, keepdims=True) loss = self.criterion(heatmaps_pred, heatmaps_gt).sum(dim=-1) if target_weight is not None: loss = loss * target_weight.view((B, K)) if self.reduction == 'mean': return loss.mean() elif self.reduction == 'none': return loss.mean(dim=-1) ================================================ FILE: tllib/vision/models/keypoint_detection/pose_resnet.py ================================================ """ Modified from https://github.com/microsoft/human-pose-estimation.pytorch @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ import torch.nn as nn from ..resnet import _resnet, Bottleneck class Upsampling(nn.Sequential): """ 3-layers deconvolution used in `Simple Baseline `_. 
""" def __init__(self, in_channel=2048, hidden_dims=(256, 256, 256), kernel_sizes=(4, 4, 4), bias=False): assert len(hidden_dims) == len(kernel_sizes), \ 'ERROR: len(hidden_dims) is different len(kernel_sizes)' layers = [] for hidden_dim, kernel_size in zip(hidden_dims, kernel_sizes): if kernel_size == 4: padding = 1 output_padding = 0 elif kernel_size == 3: padding = 1 output_padding = 1 elif kernel_size == 2: padding = 0 output_padding = 0 else: raise NotImplementedError("kernel_size is {}".format(kernel_size)) layers.append( nn.ConvTranspose2d( in_channels=in_channel, out_channels=hidden_dim, kernel_size=kernel_size, stride=2, padding=padding, output_padding=output_padding, bias=bias)) layers.append(nn.BatchNorm2d(hidden_dim)) layers.append(nn.ReLU(inplace=True)) in_channel = hidden_dim super(Upsampling, self).__init__(*layers) # init following Simple Baseline for name, m in self.named_modules(): if isinstance(m, nn.ConvTranspose2d): nn.init.normal_(m.weight, std=0.001) if bias: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) class PoseResNet(nn.Module): """ `Simple Baseline `_ for keypoint detection. Args: backbone (torch.nn.Module): Backbone to extract 2-d features from data upsampling (torch.nn.Module): Layer to upsample image feature to heatmap size feature_dim (int): The dimension of the features from upsampling layer. num_keypoints (int): Number of keypoints finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. 
Default: False """ def __init__(self, backbone, upsampling, feature_dim, num_keypoints, finetune=False): super(PoseResNet, self).__init__() self.backbone = backbone self.upsampling = upsampling self.head = nn.Conv2d(in_channels=feature_dim, out_channels=num_keypoints, kernel_size=1, stride=1, padding=0) self.finetune = finetune for m in self.head.modules(): nn.init.normal_(m.weight, std=0.001) nn.init.constant_(m.bias, 0) def forward(self, x): x = self.backbone(x) x = self.upsampling(x) x = self.head(x) return x def get_parameters(self, lr=1.): return [ {'params': self.backbone.parameters(), 'lr': 0.1 * lr if self.finetune else lr}, {'params': self.upsampling.parameters(), 'lr': lr}, {'params': self.head.parameters(), 'lr': lr}, ] def _pose_resnet(arch, num_keypoints, block, layers, pretrained_backbone, deconv_with_bias, finetune=False, progress=True, **kwargs): backbone = _resnet(arch, block, layers, pretrained_backbone, progress, **kwargs) upsampling = Upsampling(backbone.out_features, bias=deconv_with_bias) model = PoseResNet(backbone, upsampling, 256, num_keypoints, finetune) return model def pose_resnet101(num_keypoints, pretrained_backbone=True, deconv_with_bias=False, finetune=False, progress=True, **kwargs): """Constructs a Simple Baseline model with a ResNet-101 backbone. Args: num_keypoints (int): number of keypoints pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. Default: True. deconv_with_bias (bool, optional): Whether use bias in the deconvolution layer. Default: False finetune (bool, optional): Whether use 10x smaller learning rate in the backbone. Default: False progress (bool, optional): If True, displays a progress bar of the download to stderr. 
Default: True """ return _pose_resnet('resnet101', num_keypoints, Bottleneck, [3, 4, 23, 3], pretrained_backbone, deconv_with_bias, finetune, progress, **kwargs) ================================================ FILE: tllib/vision/models/object_detection/__init__.py ================================================ from . import meta_arch from . import roi_heads from . import proposal_generator from . import backbone ================================================ FILE: tllib/vision/models/object_detection/backbone/__init__.py ================================================ from .vgg import VGG, build_vgg_fpn_backbone ================================================ FILE: tllib/vision/models/object_detection/backbone/mmdetection/vgg.py ================================================ # Copyright (c) Open-MMLab. All rights reserved. # Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/vgg.py from mmcv.runner import load_checkpoint import torch.nn as nn from .weight_init import constant_init, kaiming_init, normal_init def conv3x3(in_planes, out_planes, dilation=1): "3x3 convolution with padding" return nn.Conv2d( in_planes, out_planes, kernel_size=3, padding=dilation, dilation=dilation) def make_vgg_layer(inplanes, planes, num_blocks, dilation=1, with_bn=False, ceil_mode=False): layers = [] for _ in range(num_blocks): layers.append(conv3x3(inplanes, planes, dilation)) if with_bn: layers.append(nn.BatchNorm2d(planes)) layers.append(nn.ReLU(inplace=True)) inplanes = planes layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) return layers class VGG(nn.Module): """VGG backbone. Args: depth (int): Depth of vgg, from {11, 13, 16, 19}. with_bn (bool): Use BatchNorm or not. num_classes (int): number of classes for classification. num_stages (int): VGG stages, normally 5. dilations (Sequence[int]): Dilation of each stage. out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (all param fixed). 
    def __init__(self,
                 depth,
                 with_bn=False,
                 num_classes=-1,
                 num_stages=5,
                 dilations=(1, 1, 1, 1, 1),
                 out_indices=(0, 1, 2, 3, 4),
                 frozen_stages=-1,
                 bn_eval=True,
                 bn_frozen=False,
                 ceil_mode=False,
                 with_last_pool=True):
        """Build the VGG feature stages (and the optional classifier head),
        randomly initialise weights, and freeze the first ``frozen_stages``
        stages."""
        super(VGG, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError('invalid depth {} for vgg'.format(depth))
        assert num_stages >= 1 and num_stages <= 5
        stage_blocks = self.arch_settings[depth]
        # keep only the first ``num_stages`` stages of the chosen architecture
        self.stage_blocks = stage_blocks[:num_stages]
        assert len(dilations) == num_stages
        assert max(out_indices) <= num_stages

        self.num_classes = num_classes
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages
        self.bn_eval = bn_eval
        self.bn_frozen = bn_frozen

        self.inplanes = 3
        start_idx = 0
        vgg_layers = []
        # [start, end) index range of each stage inside the flat ``features``
        # Sequential, used later by forward()/_freeze_backbone()
        self.range_sub_modules = []
        for i, num_blocks in enumerate(self.stage_blocks):
            # each block contributes conv+ReLU (plus BN when with_bn, via the
            # bool->int coercion in ``2 + with_bn``), and every stage ends in
            # one max-pool module
            num_modules = num_blocks * (2 + with_bn) + 1
            end_idx = start_idx + num_modules
            dilation = dilations[i]
            # channel widths double per stage (64, 128, 256, 512) and cap at 512
            planes = 64 * 2**i if i < 4 else 512
            vgg_layer = make_vgg_layer(
                self.inplanes,
                planes,
                num_blocks,
                dilation=dilation,
                with_bn=with_bn,
                ceil_mode=ceil_mode)
            vgg_layers.extend(vgg_layer)
            self.inplanes = planes
            self.range_sub_modules.append([start_idx, end_idx])
            start_idx = end_idx
        if not with_last_pool:
            # drop the trailing max-pool and shrink the last stage's range
            vgg_layers.pop(-1)
            self.range_sub_modules[-1][1] -= 1
        self.module_name = 'features'
        self.add_module(self.module_name, nn.Sequential(*vgg_layers))

        # classifier head is only built when classification is requested
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(512 * 7 * 7, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(True),
                nn.Dropout(),
                nn.Linear(4096, num_classes),
            )

        # initialize the model by random
        self.init_weights()

        # Optionally freeze (requires_grad=False) parts of the backbone
        self._freeze_backbone(self.frozen_stages)
# Source: https://github.com/open-mmlab/mmcv/blob/master/mmcv/cnn/utils/weight_init.py import numpy as np import torch.nn as nn def constant_init(module, val, bias=0): nn.init.constant_(module.weight, val) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def xavier_init(module, gain=1, bias=0, distribution='normal'): assert distribution in ['uniform', 'normal'] if distribution == 'uniform': nn.init.xavier_uniform_(module.weight, gain=gain) else: nn.init.xavier_normal_(module.weight, gain=gain) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def normal_init(module, mean=0, std=1, bias=0): nn.init.normal_(module.weight, mean, std) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def uniform_init(module, a=0, b=1, bias=0): nn.init.uniform_(module.weight, a, b) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def kaiming_init(module, a=0, mode='fan_out', nonlinearity='relu', bias=0, distribution='normal'): assert distribution in ['uniform', 'normal'] if distribution == 'uniform': nn.init.kaiming_uniform_( module.weight, a=a, mode=mode, nonlinearity=nonlinearity) else: nn.init.kaiming_normal_( module.weight, a=a, mode=mode, nonlinearity=nonlinearity) if hasattr(module, 'bias') and module.bias is not None: nn.init.constant_(module.bias, bias) def caffe2_xavier_init(module, bias=0): # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch # Acknowledgment to FAIR's internal code kaiming_init( module, a=1, mode='fan_in', nonlinearity='leaky_relu', distribution='uniform') ================================================ FILE: tllib/vision/models/object_detection/backbone/vgg.py ================================================ # referece from https://github.com/chengchunhsu/EveryPixelMatters import torch import torch.nn.functional as F from torch import nn from 
class FPN(nn.Module):
    """
    Module that adds FPN on top of a list of feature maps.
    The feature maps are currently supposed to be in increasing depth
    order, and must be consecutive.
    """

    def __init__(self, in_channels_list, out_channels, conv_block, top_blocks=None):
        """
        Arguments:
            in_channels_list (list[int]): number of channels for each feature map that
                will be fed; an entry of 0 marks a level that is skipped
            out_channels (int): number of channels of the FPN representation
            conv_block (callable): factory ``(in_channels, out_channels, kernel_size[, stride])``
                returning the conv module used for lateral / output convolutions
            top_blocks (nn.Module or None): if provided, an extra operation will
                be performed on the output of the last (smallest resolution)
                FPN output, and the result will extend the result list
        """
        super(FPN, self).__init__()
        self.inner_blocks = []
        self.layer_blocks = []
        for idx, in_channels in enumerate(in_channels_list, 1):
            inner_block = "fpn_inner{}".format(idx)
            layer_block = "fpn_layer{}".format(idx)
            if in_channels == 0:
                # placeholder level: no modules registered, names not recorded
                continue
            inner_block_module = conv_block(in_channels, out_channels, 1)
            layer_block_module = conv_block(out_channels, out_channels, 3, 1)
            self.add_module(inner_block, inner_block_module)
            self.add_module(layer_block, layer_block_module)
            self.inner_blocks.append(inner_block)
            self.layer_blocks.append(layer_block)
        self.top_blocks = top_blocks

    def forward(self, x):
        """
        Arguments:
            x (list[Tensor]): feature maps for each feature level.
        Returns:
            results (tuple[Tensor]): feature maps after FPN layers.
                They are ordered from highest resolution first.
        """
        # start from the deepest (lowest-resolution) level and go up
        last_inner = getattr(self, self.inner_blocks[-1])(x[-1])
        results = []
        results.append(getattr(self, self.layer_blocks[-1])(last_inner))
        for feature, inner_block, layer_block in zip(
                x[:-1][::-1], self.inner_blocks[:-1][::-1], self.layer_blocks[:-1][::-1]
        ):
            if not inner_block:
                continue
            inner_lateral = getattr(self, inner_block)(feature)
            # FIX: ``F.upsample`` is deprecated; ``F.interpolate`` is its drop-in
            # replacement. Matching the lateral's exact size (instead of a fixed
            # scale factor) keeps this robust to odd input sizes.
            inner_top_down = F.interpolate(last_inner, size=inner_lateral.shape[-2:],
                                           mode='bilinear', align_corners=False)
            last_inner = inner_lateral + inner_top_down
            results.insert(0, getattr(self, layer_block)(last_inner))
        if isinstance(self.top_blocks, LastLevelP6P7):
            # P6/P7 are computed from the deepest input and the deepest output
            last_results = self.top_blocks(x[-1], results[-1])
            results.extend(last_results)
        elif isinstance(self.top_blocks, LastLevelMaxPool):
            last_results = self.top_blocks(results[-1])
            results.extend(last_results)
        return tuple(results)
""" def __init__(self, in_channels, out_channels): super(LastLevelP6P7, self).__init__() self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) for module in [self.p6, self.p7]: nn.init.kaiming_uniform_(module.weight, a=1) nn.init.constant_(module.bias, 0) self.use_P5 = in_channels == out_channels def forward(self, c5, p5): x = p5 if self.use_P5 else c5 p6 = self.p6(x) p7 = self.p7(F.relu(p6)) return [p6, p7] class _NewEmptyTensorOp(torch.autograd.Function): @staticmethod def forward(ctx, x, new_shape): ctx.shape = x.shape return x.new_empty(new_shape) @staticmethod def backward(ctx, grad): shape = ctx.shape return _NewEmptyTensorOp.apply(grad, shape), None class Conv2d(torch.nn.Conv2d): def forward(self, x): if x.numel() > 0: return super(Conv2d, self).forward(x) # get output shape output_shape = [ (i + 2 * p - (di * (k - 1) + 1)) // d + 1 for i, p, di, k, d in zip( x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride ) ] output_shape = [x.shape[0], self.weight.shape[0]] + output_shape return _NewEmptyTensorOp.apply(x, output_shape) def conv_with_kaiming_uniform(): def make_conv( in_channels, out_channels, kernel_size, stride=1, dilation=1 ): conv = Conv2d( in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=dilation * (kernel_size - 1) // 2, dilation=dilation, bias=True ) # Caffe2 implementation uses XavierFill, which in fact # corresponds to kaiming_uniform_ in PyTorch nn.init.kaiming_uniform_(conv.weight, a=1) nn.init.constant_(conv.bias, 0) module = [conv,] if len(module) > 1: return nn.Sequential(*module) return conv return make_conv class VGGFPN(Backbone): def __init__(self, body, fpn): super(VGGFPN, self).__init__() self.body = body self.fpn = fpn self._out_features = ["p3", "p4", "p5", "p6", "p7"] self._out_feature_channels = { "p3": 256, "p4": 256, "p5": 256, "p6": 256, "p7": 256 } self._out_feature_strides = { "p3": 8, "p4": 16, "p5": 32, "p6": 64, "p7": 
128 } def forward(self, x): # print(x.shape) f = self.body(x) f = self.fpn(f) return { name: feature for name, feature in zip(self._out_features, f) } @BACKBONE_REGISTRY.register() def build_vgg_fpn_backbone(cfg, input_shape): body = VGG(depth=16, with_last_pool=True, frozen_stages=2) body.init_weights(cfg.MODEL.WEIGHTS) in_channels_stage2 = 128 # default: cfg.MODEL.RESNETS.RES2_OUT_CHANNELS (256) out_channels = 256 # default: cfg.MODEL.RESNETS.BACKBONE_OUT_CHANNELS (256) in_channels_p6p7 = out_channels fpn = FPN( in_channels_list=[ 0, 0, in_channels_stage2 * 2, in_channels_stage2 * 4, in_channels_stage2 * 4, # in_channels_stage2 * 8 ], out_channels=out_channels, conv_block=conv_with_kaiming_uniform(), top_blocks=LastLevelP6P7(in_channels_p6p7, out_channels), ) model = VGGFPN(body, fpn) model.out_channels = out_channels return model ================================================ FILE: tllib/vision/models/object_detection/meta_arch/__init__.py ================================================ from .rcnn import TLGeneralizedRCNN from .retinanet import TLRetinaNet ================================================ FILE: tllib/vision/models/object_detection/meta_arch/rcnn.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from typing import Tuple, Dict import torch from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN as GeneralizedRCNNBase, get_event_storage from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY @META_ARCH_REGISTRY.register() class TLGeneralizedRCNN(GeneralizedRCNNBase): """ Generalized R-CNN for Transfer Learning. Similar to that in in Supervised Learning, TLGeneralizedRCNN has the following three components: 1. Per-image feature extraction (aka backbone) 2. Region proposal generation 3. Per-region feature extraction and prediction Different from that in Supervised Learning, TLGeneralizedRCNN 1. accepts unlabeled images during training (return no losses) 2. 
@META_ARCH_REGISTRY.register()
class TLGeneralizedRCNN(GeneralizedRCNNBase):
    """
    Generalized R-CNN for Transfer Learning. Similar to that in Supervised
    Learning, TLGeneralizedRCNN has the following three components:

    1. Per-image feature extraction (aka backbone)
    2. Region proposal generation
    3. Per-region feature extraction and prediction

    Different from that in Supervised Learning, TLGeneralizedRCNN

    1. accepts unlabeled images during training (return no losses)
    2. returns detection outputs, features, and losses during training

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        proposal_generator: a module that generates proposals using backbone features
        roi_heads: a ROI head that performs per-region computation
        pixel_mean, pixel_std: list or tuple with #channels element, representing
            the per-channel mean and std to be used to normalize the input image
        input_format: describe the meaning of channels of input. Needed by visualization
        vis_period: the period to run visualization. Set to 0 to disable.
        finetune (bool): whether finetune the detector or train from scratch.
            Default: False

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * proposals (optional): :class:`Instances`, precomputed proposals.
          * "height", "width" (int): the output resolution of the model, used
            in inference. See :meth:`postprocess` for details.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - outputs: dict with the per-region predictions of the ROI heads plus a
          key "features" holding the backbone feature maps.
        - losses: A dict of different losses
    """

    # FIX: the docstring previously claimed ``finetune ... Default: True``
    # while the code default is False; the doc now matches the code.
    def __init__(self, *args, finetune=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetune = finetune

    def forward(self, batched_inputs: Tuple[Dict[str, torch.Tensor]], labeled=True):
        """"""
        if not self.training:
            return self.inference(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        # Ground truth is only consumed when the batch is labeled.
        if "instances" in batched_inputs[0] and labeled:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
        else:
            gt_instances = None

        features = self.backbone(images.tensor)

        if self.proposal_generator is not None:
            proposals, proposal_losses = self.proposal_generator(
                images, features, gt_instances, labeled
            )
        else:
            # Precomputed proposals must be supplied when there is no RPN.
            assert "proposals" in batched_inputs[0]
            proposals = [x["proposals"].to(self.device) for x in batched_inputs]
            proposal_losses = {}

        outputs, detector_losses = self.roi_heads(
            images, features, proposals, gt_instances, labeled
        )
        if self.vis_period > 0:
            storage = get_event_storage()
            if storage.iter % self.vis_period == 0:
                self.visualize_training(batched_inputs, proposals)

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        outputs['features'] = features
        return outputs, losses

    def get_parameters(self, lr=1.):
        """Return a parameter list which decides optimization hyper-parameters,
        such as the learning rate of each layer
        """
        # The backbone gets a 10x smaller lr when finetuning a pretrained model.
        return [
            (self.backbone, 0.1 * lr if self.finetune else lr),
            (self.proposal_generator, lr),
            (self.roi_heads, lr),
        ]
class TLRetinaNet(RetinaNetBase):
    """
    RetinaNet for Transfer Learning.

    Different from that in Supervised Learning, TLRetinaNet

    1. accepts unlabeled images during training (return no losses)
    2. returns detection outputs, features, and losses during training

    Args:
        backbone: a backbone module, must follow detectron2's backbone interface
        head (nn.Module): a module that predicts logits and regression deltas
            for each level from a list of per-level features
        head_in_features (Tuple[str]): Names of the input feature maps to be used in head
        anchor_generator (nn.Module): a module that creates anchors from a
            list of features. Usually an instance of :class:`AnchorGenerator`
        box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to
            instance boxes
        anchor_matcher (Matcher): label the anchors by matching them with ground truth.
        num_classes (int): number of classes. Used to label background proposals.
        focal_loss_alpha (float): focal_loss_alpha
        focal_loss_gamma (float): focal_loss_gamma
        smooth_l1_beta (float): smooth_l1_beta
        box_reg_loss_type (str): Options are "smooth_l1", "giou"
        test_score_thresh (float): Inference cls score threshold, only anchors with
            score > INFERENCE_TH are considered for inference (to improve speed)
        test_topk_candidates (int): Select topk candidates before NMS
        test_nms_thresh (float): Overlap threshold used for non-maximum suppression
            (suppress boxes with IoU >= this threshold)
        max_detections_per_image (int): Maximum number of detections to return per
            image during inference (100 is based on the limit established for the
            COCO dataset).
        pixel_mean (Tuple[float]): Values to be used for image normalization (BGR order).
        pixel_std (Tuple[float]): When using pre-trained models in Detectron1 or any
            MSRA models, std has been absorbed into its conv1 weights, so the std
            needs to be set 1. Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
        vis_period (int): The period (in terms of steps) for minibatch visualization
            at train time. Set to 0 to disable.
        input_format (str): Whether the model needs RGB, YUV, HSV etc.
        finetune (bool): whether finetune the detector or train from scratch.
            Default: False

    Inputs:
        - batched_inputs: a list, batched outputs of :class:`DatasetMapper`.
          Each item is a dict that contains:

          * image: Tensor, image in (C, H, W) format.
          * instances (optional): groundtruth :class:`Instances`
          * "height", "width" (int): the output resolution of the model, used
            in inference. See :meth:`postprocess` for details.
        - labeled (bool, optional): whether has ground-truth label

    Outputs:
        - outputs: dict holding a key "features" with the selected backbone
          feature maps (training only).
        - losses: A dict of different losses
    """

    # FIX: the docstring previously claimed ``finetune ... Default: True``
    # while the code default is False; the doc now matches the code.
    def __init__(self, *args, finetune=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.finetune = finetune

    def forward(self, batched_inputs: Tuple[Dict[str, Tensor]], labeled=True):
        """"""
        images = self.preprocess_image(batched_inputs)
        features = self.backbone(images.tensor)
        features = [features[f] for f in self.head_in_features]
        predictions = self.head(features)

        if self.training:
            if labeled:
                assert not torch.jit.is_scripting(), "Not supported"
                assert "instances" in batched_inputs[0], \
                    "Instance annotations are missing in training!"
                gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
                losses = self.forward_training(images, features, predictions, gt_instances)
            else:
                # Unlabeled images contribute no supervised loss.
                losses = {}
            outputs = {"features": features}
            return outputs, losses
        else:
            results = self.forward_inference(images, features, predictions)
            if torch.jit.is_scripting():
                return results
            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                results, batched_inputs, images.image_sizes
            ):
                # Rescale detections to the caller-requested output resolution.
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})
            return processed_results

    def get_parameters(self, lr=1.):
        """Return a parameter list which decides optimization hyper-parameters,
        such as the learning rate of each layer
        """
        # The pretrained bottom-up body gets a 10x smaller lr when finetuning;
        # the FPN lateral/output convs and the head are trained at full lr.
        return [
            (self.backbone.bottom_up, 0.1 * lr if self.finetune else lr),
            (self.backbone.fpn_lateral4, lr),
            (self.backbone.fpn_output4, lr),
            (self.backbone.fpn_lateral5, lr),
            (self.backbone.fpn_output5, lr),
            (self.backbone.top_block, lr),
            (self.head, lr),
        ]
@PROPOSAL_GENERATOR_REGISTRY.register()
class TLRPN(RPN):
    """Region Proposal Network, introduced by `Faster R-CNN`, extended so that
    unlabeled images can pass through during training (``labeled=False``
    skips anchor labeling and loss computation).

    Constructor arguments are identical to :class:`RPN` (in_features, head,
    anchor_generator, anchor_matcher, box2box_transform, batch_size_per_image,
    positive_fraction, pre/post_nms_topk, nms_thresh, min_box_size,
    anchor_boundary_thresh, loss_weight, box_reg_loss_type, smooth_l1_beta).

    Inputs:
        - images (ImageList): input images of length `N`
        - features (dict[str, Tensor]): mapping from feature map name to tensor
        - gt_instances (list[Instances], optional): length-`N` ground truth
        - labeled (bool, optional): whether ground-truth labels are available.
          Default: True

    Outputs:
        - proposals: list[Instances] with fields "proposal_boxes",
          "objectness_logits"
        - loss: dict[Tensor] (empty when unlabeled or not training)
    """

    def __init__(self, *args, **kwargs):
        super(TLRPN, self).__init__(*args, **kwargs)

    def forward(
        self,
        images: ImageList,
        features: Dict[str, torch.Tensor],
        gt_instances: Optional[List[Instances]] = None,
        labeled: Optional[bool] = True
    ):
        feature_maps = [features[name] for name in self.in_features]
        anchors = self.anchor_generator(feature_maps)
        objectness_logits, anchor_deltas = self.rpn_head(feature_maps)

        # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A)
        objectness_logits = [
            score.permute(0, 2, 3, 1).flatten(1) for score in objectness_logits
        ]
        box_dim = self.anchor_generator.box_dim
        # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B)
        anchor_deltas = [
            deltas.view(deltas.shape[0], -1, box_dim, deltas.shape[-2], deltas.shape[-1])
            .permute(0, 3, 4, 1, 2)
            .flatten(1, -2)
            for deltas in anchor_deltas
        ]

        if self.training and labeled:
            gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances)
            losses = self.losses(
                anchors, objectness_logits, gt_labels, anchor_deltas, gt_boxes
            )
        else:
            # No supervised signal: still emit proposals, but no losses.
            losses = {}

        proposals = self.predict_proposals(
            anchors, objectness_logits, anchor_deltas, images.image_sizes
        )
        return proposals, losses
@ROI_HEADS_REGISTRY.register()
class TLRes5ROIHeads(Res5ROIHeads):
    """
    The ROIHeads in a typical "C4" R-CNN model, where the box and mask head
    share the cropping and the per-region feature computation by a Res5 block.

    Besides the standard :class:`Res5ROIHeads` behavior, this variant accepts
    unlabeled images during training (``labeled=False``): proposals are then
    sampled without ground truth and no losses are produced.

    Constructor arguments are identical to :class:`Res5ROIHeads`
    (in_features, pooler, res5, box_predictor, mask_head).

    Inputs:
        - images (ImageList)
        - features (dict[str, Tensor]): feature-map name -> tensor
        - proposals (list[Instances]): per-image object proposals with fields
          "proposal_boxes" and "objectness_logits"
        - targets (list[Instances], optional): per-image ground truth
          (training only)
        - labeled (bool, optional): whether ground-truth labels are available.
          Default: True

    Outputs:
        - training: (outputs dict with "predictions" and "box_features", losses
          dict); inference: (list[Instances], empty dict)
    """

    def __init__(self, *args, **kwargs):
        super(TLRes5ROIHeads, self).__init__(*args, **kwargs)

    def forward(self, images, features, proposals, targets=None, labeled=True):
        """"""
        del images
        if self.training:
            proposals = (
                self.label_and_sample_proposals(proposals, targets)
                if labeled
                else self.sample_unlabeled_proposals(proposals)
            )
        del targets

        boxes_per_image = [p.proposal_boxes for p in proposals]
        box_features = self._shared_roi_transform(
            [features[name] for name in self.in_features], boxes_per_image
        )
        # Global average pooling over the spatial dims feeds the box predictor.
        predictions = self.box_predictor(box_features.mean(dim=[2, 3]))

        if not self.training:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

        del features  # no longer needed; free memory before the mask branch
        if labeled:
            losses = self.box_predictor.losses(predictions, proposals)
            if self.mask_on:
                proposals, fg_selection_masks = select_foreground_proposals(
                    proposals, self.num_classes
                )
                # The ROI feature transform is shared between boxes and masks,
                # so features need not be recomputed. The mask loss is only
                # defined on foreground proposals, hence the selection above.
                mask_features = box_features[torch.cat(fg_selection_masks, dim=0)]
                losses.update(self.mask_head(mask_features, proposals))
        else:
            losses = {}

        outputs = {
            'predictions': predictions[0],
            'box_features': box_features,
        }
        return outputs, losses

    @torch.no_grad()
    def sample_unlabeled_proposals(
            self, proposals: List[Instances]
    ) -> List[Instances]:
        """
        Prepare some unlabeled proposals: keep the top
        ``self.batch_size_per_image`` proposals of each image.

        Args:
            proposals (list[Instances]): length `N` list of `Instances`, with
                fields "proposal_boxes" and "objectness_logits".

        Returns:
            length `N` list of `Instances`s containing the proposals sampled
            for training.
        """
        return [p[:self.batch_size_per_image] for p in proposals]
@ROI_HEADS_REGISTRY.register()
class TLStandardROIHeads(StandardROIHeads):
    """
    It's "standard" in a sense that there is no ROI transform sharing
    or feature sharing between tasks. Each head independently processes the
    input features by each head's own pooler and head.

    Besides the standard :class:`StandardROIHeads` behavior, this variant
    accepts unlabeled images during training (``labeled=False``): proposals
    are then sampled without ground truth and no losses are produced.

    Constructor arguments are identical to :class:`StandardROIHeads`
    (box_in_features, box_pooler, box_head, box_predictor, mask_*,
    keypoint_*, train_on_pred_boxes).

    Inputs:
        - images (ImageList)
        - features (dict[str, Tensor]): feature-map name -> tensor
        - proposals (list[Instances]): per-image object proposals with fields
          "proposal_boxes" and "objectness_logits"
        - targets (list[Instances], optional): per-image ground truth
          (training only)
        - labeled (bool, optional): whether ground-truth labels are available.
          Default: True

    Outputs:
        - training: (outputs dict, losses dict); inference:
          (list[Instances], empty dict)
    """

    def __init__(self, *args, **kwargs):
        super(TLStandardROIHeads, self).__init__(*args, **kwargs)

    def forward(self, images, features, proposals, targets=None, labeled=True):
        """"""
        del images
        if self.training:
            if labeled:
                proposals = self.label_and_sample_proposals(proposals, targets)
            else:
                proposals = self.sample_unlabeled_proposals(proposals)
        del targets

        if self.training:
            if labeled:
                outputs, losses = self._forward_box(features, proposals)
                # Usually the original proposals used by the box head are used
                # by the mask, keypoint heads. But when
                # `self.train_on_pred_boxes is True`, proposals will contain
                # boxes predicted by the box head.
                losses.update(self._forward_mask(features, proposals))
                losses.update(self._forward_keypoint(features, proposals))
            else:
                # BUG FIX: `outputs` was previously unbound on this path,
                # raising NameError at the return below. Return an empty dict
                # so callers (e.g. TLGeneralizedRCNN) can still attach
                # features to it.
                outputs = {}
                losses = {}
            return outputs, losses
        else:
            pred_instances = self._forward_box(features, proposals)
            # During inference cascaded prediction is used: the mask and
            # keypoints heads are only applied to the top scoring box
            # detections.
            pred_instances = self.forward_with_given_boxes(features, pred_instances)
            return pred_instances, {}

    def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]):
        """
        Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`,
            the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument.

        Args:
            features (dict[str, Tensor]): mapping from feature map names to tensor.
                Same as in :meth:`ROIHeads.forward`.
            proposals (list[Instances]): the per-image object proposals with
                their matching ground truth.
                Each has fields "proposal_boxes", and "objectness_logits",
                "gt_classes", "gt_boxes".

        Returns:
            In training, (outputs dict, losses dict).
            In inference, a list of `Instances`, the predicted instances.
        """
        features = [features[f] for f in self.box_in_features]
        box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals])
        box_features = self.box_head(box_features)
        predictions = self.box_predictor(box_features)

        if self.training:
            losses = self.box_predictor.losses(predictions, proposals)
            # proposals is modified in-place below, so losses must be computed first.
            if self.train_on_pred_boxes:
                # BUG FIX: `Boxes` was never imported at module level, so this
                # branch raised NameError; import it locally here.
                from detectron2.structures import Boxes
                with torch.no_grad():
                    pred_boxes = self.box_predictor.predict_boxes_for_gt_classes(
                        predictions, proposals
                    )
                    for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes):
                        proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image)
            outputs = {
                'predictions': predictions[0],
                'box_features': box_features
            }
            return outputs, losses
        else:
            pred_instances, _ = self.box_predictor.inference(predictions, proposals)
            return pred_instances

    @torch.no_grad()
    def sample_unlabeled_proposals(
            self, proposals: List[Instances]
    ) -> List[Instances]:
        """
        Prepare some unlabeled proposals: keep the top
        ``self.batch_size_per_image`` proposals of each image.

        Args:
            proposals (list[Instances]): length `N` list of `Instances`, with
                fields "proposal_boxes" and "objectness_logits".

        Returns:
            length `N` list of `Instances`s containing the proposals sampled
            for training.
        """
        return [proposal[:self.batch_size_per_image] for proposal in proposals]
While during testing, :math:`bn\_f` is used as feature. This may be a little confusing. The figures in the origin paper will help you understand better. """ def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: Optional[nn.Module] = None, bottleneck_dim: Optional[int] = -1, finetune=True, pool_layer=None): super(ReIdentifier, self).__init__() if pool_layer is None: self.pool_layer = nn.Sequential( nn.AdaptiveAvgPool2d(output_size=(1, 1)), nn.Flatten() ) else: self.pool_layer = pool_layer self.backbone = backbone self.num_classes = num_classes if bottleneck is None: feature_bn = nn.BatchNorm1d(backbone.out_features) self.bottleneck = feature_bn self._features_dim = backbone.out_features else: feature_bn = nn.BatchNorm1d(bottleneck_dim) self.bottleneck = nn.Sequential( bottleneck, feature_bn ) self._features_dim = bottleneck_dim self.head = nn.Linear(self.features_dim, num_classes, bias=False) self.finetune = finetune # initialize feature_bn and head feature_bn.bias.requires_grad_(False) init.constant_(feature_bn.weight, 1) init.constant_(feature_bn.bias, 0) init.normal_(self.head.weight, std=0.001) @property def features_dim(self) -> int: """The dimension of features before the final `head` layer""" return self._features_dim def forward(self, x: torch.Tensor): """""" f = self.pool_layer(self.backbone(x)) bn_f = self.bottleneck(f) if not self.training: return bn_f predictions = self.head(bn_f) return predictions, f def get_parameters(self, base_lr=1.0, rate=0.1) -> List[Dict]: """A parameter list which decides optimization hyper-parameters, such as the relative learning rate of each layer """ params = [ {"params": self.backbone.parameters(), "lr": rate * base_lr if self.finetune else 1.0 * base_lr}, {"params": self.bottleneck.parameters(), "lr": 1.0 * base_lr}, {"params": self.head.parameters(), "lr": 1.0 * base_lr}, ] return params ================================================ FILE: tllib/vision/models/reid/loss.py 
================================================ """ Modified from https://github.com/yxgeee/MMT @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ import torch import torch.nn as nn import torch.nn.functional as F def pairwise_euclidean_distance(x, y): """Compute pairwise euclidean distance between two sets of features""" m, n = x.size(0), y.size(0) dist_mat = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n) + \ torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t() \ - 2 * torch.matmul(x, y.t()) # for numerical stability dist_mat = dist_mat.clamp(min=1e-12).sqrt() return dist_mat def hard_examples_mining(dist_mat, identity_mat, return_idxes=False): r"""Select hard positives and hard negatives according to `In defense of the Triplet Loss for Person Re-Identification (ICCV 2017) `_ Args: dist_mat (tensor): pairwise distance matrix between two sets of features identity_mat (tensor): a matrix of shape :math:`(N, M)`. If two images :math:`P[i]` of set :math:`P` and :math:`Q[j]` of set :math:`Q` come from the same person, then :math:`identity\_mat[i, j] = 1`, otherwise :math:`identity\_mat[i, j] = 0` return_idxes (bool, optional): if True, also return indexes of hard examples. Default: False """ # the implementation here is a little tricky, dist_mat contains pairwise distance between probe image and other # images in current mini-batch. As we want to select positive examples of the same person, we add a constant # negative offset on other images before sorting. As a result, images of the **same** person will rank first. sorted_dist_mat, sorted_idxes = torch.sort(dist_mat + (-1e7) * (1 - identity_mat), dim=1, descending=True) dist_ap = sorted_dist_mat[:, 0] hard_positive_idxes = sorted_idxes[:, 0] # the implementation here is similar to above code, we add a constant positive offset on images of same person # before sorting. Besides, we sort in ascending order. As a result, images of **different** persons will rank first. 
sorted_dist_mat, sorted_idxes = torch.sort(dist_mat + 1e7 * identity_mat, dim=1, descending=False) dist_an = sorted_dist_mat[:, 0] hard_negative_idxes = sorted_idxes[:, 0] if return_idxes: return dist_ap, dist_an, hard_positive_idxes, hard_negative_idxes return dist_ap, dist_an class CrossEntropyLossWithLabelSmooth(nn.Module): r"""Cross entropy loss with label smooth from `Rethinking the Inception Architecture for Computer Vision (CVPR 2016) `_. Given one-hot labels :math:`labels \in R^C`, where :math:`C` is the number of classes, smoothed labels are calculated as .. math:: smoothed\_labels = (1 - \epsilon) \times labels + \epsilon \times \frac{1}{C} We use smoothed labels when calculating cross entropy loss and this can be helpful for preventing over-fitting. Args: num_classes (int): number of classes. epsilon (float): a float number that controls the smoothness. Inputs: - y (tensor): unnormalized classifier predictions, :math:`y` - labels (tensor): ground truth labels, :math:`labels` Shape: - y: :math:`(minibatch, C)`, where :math:`C` is the number of classes - labels: :math:`(minibatch, )` """ def __init__(self, num_classes, epsilon=0.1): super(CrossEntropyLossWithLabelSmooth, self).__init__() self.num_classes = num_classes self.epsilon = epsilon self.log_softmax = nn.LogSoftmax(dim=1).cuda() def forward(self, y, labels): log_prob = self.log_softmax(y) labels = torch.zeros_like(log_prob).scatter_(1, labels.unsqueeze(1), 1) labels = (1 - self.epsilon) * labels + self.epsilon / self.num_classes loss = (- labels * log_prob).mean(0).sum() return loss class TripletLoss(nn.Module): """Triplet loss augmented with batch hard from `In defense of the Triplet Loss for Person Re-Identification (ICCV 2017) `_. Args: margin (float): margin of triplet loss normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss. Default: False. 
""" def __init__(self, margin, normalize_feature=False): super(TripletLoss, self).__init__() self.margin = margin self.normalize_feature = normalize_feature self.margin_loss = nn.MarginRankingLoss(margin=margin).cuda() def forward(self, f, labels): if self.normalize_feature: # equivalent to cosine similarity f = F.normalize(f) dist_mat = pairwise_euclidean_distance(f, f) n = dist_mat.size(0) identity_mat = labels.expand(n, n).eq(labels.expand(n, n).t()).float() dist_ap, dist_an = hard_examples_mining(dist_mat, identity_mat) y = torch.ones_like(dist_ap) loss = self.margin_loss(dist_an, dist_ap, y) return loss class TripletLossXBM(nn.Module): r"""Triplet loss augmented with batch hard from `In defense of the Triplet Loss for Person Re-Identification (ICCV 2017) `_. The only difference from triplet loss lies in that both features from current mini batch and external storage (XBM) are involved. Args: margin (float, optional): margin of triplet loss. Default: 0.3 normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss. 
Default: False Inputs: - f (tensor): features of current mini batch, :math:`f` - labels (tensor): identity labels for current mini batch, :math:`labels` - xbm_f (tensor): features collected from XBM, :math:`xbm\_f` - xbm_labels (tensor): corresponding identity labels of xbm_f, :math:`xbm\_labels` Shape: - f: :math:`(minibatch, F)`, where :math:`F` is the feature dimension - labels: :math:`(minibatch, )` - xbm_f: :math:`(minibatch, F)` - xbm_labels: :math:`(minibatch, )` """ def __init__(self, margin=0.3, normalize_feature=False): super(TripletLossXBM, self).__init__() self.margin = margin self.normalize_feature = normalize_feature self.ranking_loss = nn.MarginRankingLoss(margin=margin) def forward(self, f, labels, xbm_f, xbm_labels): if self.normalize_feature: # equivalent to cosine similarity f = F.normalize(f) xbm_f = F.normalize(xbm_f) dist_mat = pairwise_euclidean_distance(f, xbm_f) # hard examples mining n, m = f.size(0), xbm_f.size(0) identity_mat = labels.expand(m, n).t().eq(xbm_labels.expand(n, m)).float() dist_ap, dist_an = hard_examples_mining(dist_mat, identity_mat) # Compute ranking hinge loss y = torch.ones_like(dist_an) loss = self.ranking_loss(dist_an, dist_ap, y) return loss class SoftTripletLoss(nn.Module): r"""Soft triplet loss from `Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (ICLR 2020) `_. Consider a triplet :math:`x,x_p,x_n` (anchor, positive, negative), corresponding features are :math:`f,f_p,f_n`. We optimize for a smaller distance between :math:`f` and :math:`f_p` and a larger distance between :math:`f` and :math:`f_n`. Inner product is adopted as their similarity measure, soft triplet loss is thus defined as .. math:: loss = \mathcal{L}_{\text{bce}}(\frac{\text{exp}(f^Tf_p)}{\text{exp}(f^Tf_p)+\text{exp}(f^Tf_n)}, 1) where :math:`\mathcal{L}_{\text{bce}}` means binary cross entropy loss. We denote the first term in above loss function as :math:`T`. 
When features from another teacher network can be obtained, we can calculate :math:`T_{teacher}` as labels, resulting in the following soft version .. math:: loss = \mathcal{L}_{\text{bce}}(T, T_{teacher}) Args: margin (float, optional): margin of triplet loss. If None, soft labels from another network will be adopted when computing loss. Default: None. normalize_feature (bool, optional): if True, normalize features into unit norm first before computing loss. Default: False. """ def __init__(self, margin=None, normalize_feature=False): super(SoftTripletLoss, self).__init__() self.margin = margin self.normalize_feature = normalize_feature def forward(self, features_1, features_2, labels): if self.normalize_feature: # equal to cosine similarity features_1 = F.normalize(features_1) features_2 = F.normalize(features_2) dist_mat = pairwise_euclidean_distance(features_1, features_1) assert dist_mat.size(0) == dist_mat.size(1) n = dist_mat.size(0) identity_mat = labels.expand(n, n).eq(labels.expand(n, n).t()).float() dist_ap, dist_an, ap_idxes, an_idxes = hard_examples_mining(dist_mat, identity_mat, return_idxes=True) assert dist_an.size(0) == dist_ap.size(0) triple_dist = torch.stack((dist_ap, dist_an), dim=1) triple_dist = F.log_softmax(triple_dist, dim=1) if self.margin is not None: loss = (- self.margin * triple_dist[:, 0] - (1 - self.margin) * triple_dist[:, 1]).mean() return loss dist_mat_ref = pairwise_euclidean_distance(features_2, features_2) dist_ap_ref = torch.gather(dist_mat_ref, 1, ap_idxes.view(n, 1).expand(n, n))[:, 0] dist_an_ref = torch.gather(dist_mat_ref, 1, an_idxes.view(n, 1).expand(n, n))[:, 0] triple_dist_ref = torch.stack((dist_ap_ref, dist_an_ref), dim=1) triple_dist_ref = F.softmax(triple_dist_ref, dim=1).detach() loss = (- triple_dist_ref * triple_dist).sum(dim=1).mean() return loss class CrossEntropyLoss(nn.Module): r"""We use :math:`C` to denote the number of classes, :math:`N` to denote mini-batch size, this criterion expects unnormalized 
predictions :math:`y\_{logits}` of shape :math:`(N, C)` and :math:`target\_{logits}` of the same shape :math:`(N, C)`. Then we first normalize them into probability distributions among classes .. math:: y = \text{softmax}(y\_{logits}) .. math:: target = \text{softmax}(target\_{logits}) Final objective is calculated as .. math:: \text{loss} = \frac{1}{N} \sum_{i=1}^{N} \sum_{j=1}^C -target_i^j \times \text{log} (y_i^j) """ def __init__(self): super(CrossEntropyLoss, self).__init__() self.log_softmax = nn.LogSoftmax(dim=1).cuda() def forward(self, y, labels): log_prob = self.log_softmax(y) loss = (- F.softmax(labels, dim=1).detach() * log_prob).sum(dim=1).mean() return loss ================================================ FILE: tllib/vision/models/reid/resnet.py ================================================ """ @author: Baixu Chen @contact: cbx_99_hasta@outlook.com """ from tllib.vision.models.resnet import ResNet, load_state_dict_from_url, model_urls, BasicBlock, Bottleneck __all__ = ['reid_resnet18', 'reid_resnet34', 'reid_resnet50', 'reid_resnet101'] class ReidResNet(ResNet): r"""Modified `ResNet` architecture for ReID from `Mutual Mean-Teaching: Pseudo Label Refinery for Unsupervised Domain Adaptation on Person Re-identification (ICLR 2020) `_. We change stride of :math:`layer4\_group1\_conv2, layer4\_group1\_downsample1` to 1. During forward pass, we will not activate `self.relu`. Please refer to source code for details. 
""" def __init__(self, *args, **kwargs): super(ReidResNet, self).__init__(*args, **kwargs) self.layer4[0].conv2.stride = (1, 1) self.layer4[0].downsample[0].stride = (1, 1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) # x = self.relu(x) x = self.maxpool(x) x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x) return x def _reid_resnet(arch, block, layers, pretrained, progress, **kwargs): model = ReidResNet(block, layers, **kwargs) if pretrained: model_dict = model.state_dict() pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress) # remove keys from pretrained dict that doesn't appear in model dict pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model.load_state_dict(pretrained_dict, strict=False) return model def reid_resnet18(pretrained=False, progress=True, **kwargs): r"""Constructs a Reid-ResNet-18 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _reid_resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs) def reid_resnet34(pretrained=False, progress=True, **kwargs): r"""Constructs a Reid-ResNet-34 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _reid_resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs) def reid_resnet50(pretrained=False, progress=True, **kwargs): r"""Constructs a Reid-ResNet-50 model. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _reid_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def reid_resnet101(pretrained=False, progress=True, **kwargs): r"""Constructs a Reid-ResNet-101 model. 
================================================
FILE: tllib/vision/models/resnet.py
================================================
"""
Modified based on torchvision.models.resnet.
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
from torchvision import models
from torch.hub import load_state_dict_from_url
from torchvision.models.resnet import BasicBlock, Bottleneck, model_urls
import copy

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
           'resnext50_32x4d', 'resnext101_32x8d', 'wide_resnet50_2', 'wide_resnet101_2']


class ResNet(models.ResNet):
    """ResNets without fully connected layer"""

    def __init__(self, *args, **kwargs):
        super(ResNet, self).__init__(*args, **kwargs)
        # remember the fc input width so callers know the feature dimension
        self._out_features = self.fc.in_features

    def forward(self, x):
        """"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # pooling/flatten/fc of the stock torchvision ResNet are deliberately
        # skipped: callers receive the raw conv feature map
        # x = self.avgpool(x)
        # x = torch.flatten(x, 1)
        # x = x.view(-1, self._out_features)
        return x

    @property
    def out_features(self) -> int:
        """The dimension of output features"""
        return self._out_features

    def copy_head(self) -> nn.Module:
        """Copy the origin fully connected layer"""
        return copy.deepcopy(self.fc)


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    # shared factory: build the headless ResNet, then optionally load the
    # torchvision ImageNet checkpoint for ``arch``
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        model_dict = model.state_dict()
        pretrained_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        # remove keys from pretrained dict that doesn't appear in model dict
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model.load_state_dict(pretrained_dict, strict=False)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" `_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    # 32 parallel groups with 4 channels per group
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)
for Deep Neural Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['groups'] = 32 kwargs['width_per_group'] = 4 return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def resnext101_32x8d(pretrained=False, progress=True, **kwargs): r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['groups'] = 32 kwargs['width_per_group'] = 8 return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs) def wide_resnet50_2(pretrained=False, progress=True, **kwargs): r"""Wide ResNet-50-2 model from `"Wide Residual Networks" `_ The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ kwargs['width_per_group'] = 64 * 2 return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs) def wide_resnet101_2(pretrained=False, progress=True, **kwargs): r"""Wide ResNet-101-2 model from `"Wide Residual Networks" `_ The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
# ================================================
# FILE: tllib/vision/models/segmentation/__init__.py
# ================================================
from .deeplabv2 import *

__all__ = ['deeplabv2']

# ================================================
# FILE: tllib/vision/models/segmentation/deeplabv2.py
# ================================================
"""
@author: Junguang Jiang
@contact: JiangJunguang1123@outlook.com
"""
import torch.nn as nn
# FIX: ``torchvision.models.utils`` has been removed from recent torchvision
# releases; ``load_state_dict_from_url`` lives in ``torch.hub``, which is also
# where the sibling file tllib/vision/models/resnet.py already imports it from.
from torch.hub import load_state_dict_from_url

model_urls = {
    'deeplabv2_resnet101': 'https://cloud.tsinghua.edu.cn/f/2d9a7fc43ce34f76803a/?dl=1'
}

affine_par = True


class Bottleneck(nn.Module):
    # bottleneck residual block with frozen (non-trainable) batch-norm layers
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)  # change
        self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
        for i in self.bn1.parameters():
            i.requires_grad = False
        padding = dilation
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,  # change
                               padding=padding, bias=False, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
        for i in self.bn2.parameters():
            i.requires_grad = False
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
        for i in self.bn3.parameters():
            i.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out


class ASPP_V2(nn.Module):
    """Atrous spatial pyramid pooling (DeepLab v2): parallel dilated 3x3
    convolutions whose per-class outputs are summed."""

    def __init__(self, inplanes, dilation_series, padding_series, num_classes):
        super(ASPP_V2, self).__init__()
        self.conv2d_list = nn.ModuleList()
        for dilation, padding in zip(dilation_series, padding_series):
            self.conv2d_list.append(
                nn.Conv2d(inplanes, num_classes, kernel_size=3, stride=1, padding=padding,
                          dilation=dilation, bias=True))
        for m in self.conv2d_list:
            m.weight.data.normal_(0, 0.01)

    def forward(self, x):
        out = self.conv2d_list[0](x)
        for i in range(len(self.conv2d_list) - 1):
            out += self.conv2d_list[i + 1](x)
        return out


class ResNet(nn.Module):
    """Dilated ResNet backbone used by DeepLab v2 (layer3/layer4 keep stride 1
    with dilation 2/4, so spatial resolution is preserved)."""

    def __init__(self, block, layers):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
        for i in self.bn1.parameters():
            i.requires_grad = False
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True)  # change
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.normal_(0, 0.01)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion, affine=affine_par))
            # freeze the shortcut batch-norm, like every other BN in the backbone
            for i in downsample._modules['1'].parameters():
                i.requires_grad = False
        layers = []
        layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x


class Deeplab(nn.Module):
    def __init__(self, backbone, classifier, num_classes):
        super(Deeplab, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.num_classes = num_classes

    def forward(self, x):
        x = self.backbone(x)
        y = self.classifier(x)
        return y

    def get_1x_lr_params_NOscale(self):
        """
        This generator returns all the parameters of the net except for the last classification layer.
        Note that for each batchnorm layer, requires_grad is set to False in deeplab_resnet.py,
        therefore this function does not return any batchnorm parameter
        """
        layers = [self.backbone.conv1, self.backbone.bn1, self.backbone.layer1, self.backbone.layer2,
                  self.backbone.layer3, self.backbone.layer4]
        for layer in layers:
            for module in layer.modules():
                for param in module.parameters():
                    if param.requires_grad:
                        yield param

    def get_10x_lr_params(self):
        """
        This generator returns all the parameters for the last layer of the net,
        which does the classification of pixel into classes
        """
        for param in self.classifier.parameters():
            yield param

    def get_parameters(self, lr=1.):
        # backbone learns 10x slower than the freshly-initialized classifier
        return [
            {'params': self.get_1x_lr_params_NOscale(), 'lr': 0.1 * lr},
            {'params': self.get_10x_lr_params(), 'lr': lr}
        ]


def deeplabv2_resnet101(num_classes=19, pretrained_backbone=True):
    """Constructs a DeepLabV2 model with a ResNet-101 backbone.

    Args:
        num_classes (int, optional): number of classes. Default: 19
        pretrained_backbone (bool, optional): If True, returns a model pre-trained on ImageNet. Default: True.
    """
    backbone = ResNet(Bottleneck, [3, 4, 23, 3])
    if pretrained_backbone:
        # download from Internet
        saved_state_dict = load_state_dict_from_url(model_urls['deeplabv2_resnet101'],
                                                    map_location=lambda storage, loc: storage,
                                                    file_name="deeplabv2_resnet101.pth")
        new_params = backbone.state_dict().copy()
        # checkpoint keys carry an extra leading module name; strip it and
        # drop the old 'layer5' classification head
        for i in saved_state_dict:
            i_parts = i.split('.')
            if not i_parts[1] == 'layer5':
                new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
        backbone.load_state_dict(new_params)
    classifier = ASPP_V2(2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes)
    return Deeplab(backbone, classifier, num_classes)


# ================================================
# FILE: tllib/vision/transforms/__init__.py
# ================================================
import math
import random
from PIL import Image
import numpy as np
import torch
from torchvision.transforms import Normalize


class ResizeImage(object):
    """Resize the input PIL Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like (h, w),
            output size will be matched to this. If size is an int, output size will be (size, size)
    """

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img):
        th, tw = self.size
        return img.resize((th, tw))

    def __repr__(self):
        return self.__class__.__name__ + '(size={0})'.format(self.size)


class MultipleApply:
    """Apply a list of transformations to an image and get multiple transformed images.

    Args:
        transforms (list or tuple): list of transformations

    Example:
        >>> transform1 = T.Compose([
        ...     ResizeImage(256),
        ...     T.RandomCrop(224)
        ... ])
        >>> transform2 = T.Compose([
        ...     ResizeImage(256),
        ...     T.RandomCrop(224),
        ... ])
        >>> multiply_transform = MultipleApply([transform1, transform2])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image):
        return [t(image) for t in self.transforms]

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
]) >>> multiply_transform = MultipleApply([transform1, transform2]) """ def __init__(self, transforms): self.transforms = transforms def __call__(self, image): return [t(image) for t in self.transforms] def __repr__(self): format_string = self.__class__.__name__ + '(' for t in self.transforms: format_string += '\n' format_string += ' {0}'.format(t) format_string += '\n)' return format_string class Denormalize(Normalize): """DeNormalize a tensor image with mean and standard deviation. Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n`` channels, this transform will denormalize each channel of the input ``torch.*Tensor`` i.e., ``output[channel] = input[channel] * std[channel] + mean[channel]`` .. note:: This transform acts out of place, i.e., it does not mutate the input tensor. Args: mean (sequence): Sequence of means for each channel. std (sequence): Sequence of standard deviations for each channel. """ def __init__(self, mean, std): mean = np.array(mean) std = np.array(std) super().__init__((-mean / std).tolist(), (1 / std).tolist()) class NormalizeAndTranspose: """ First, normalize a tensor image with mean and standard deviation. Then, convert the shape (H x W x C) to shape (C x H x W). """ def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): self.mean = np.array(mean, dtype=np.float32) def __call__(self, image): if isinstance(image, Image.Image): image = np.asarray(image, np.float32) # change to BGR image = image[:, :, ::-1] # normalize image -= self.mean image = image.transpose((2, 0, 1)).copy() elif isinstance(image, torch.Tensor): # change to BGR image = image[:, :, [2, 1, 0]] # normalize image -= torch.from_numpy(self.mean).to(image.device) image = image.permute((2, 0, 1)) else: raise NotImplementedError(type(image)) return image class DeNormalizeAndTranspose: """ First, convert a tensor image from the shape (C x H x W ) to shape (H x W x C). Then, denormalize it with mean and standard deviation. 
""" def __init__(self, mean=(104.00698793, 116.66876762, 122.67891434)): self.mean = np.array(mean, dtype=np.float32) def __call__(self, image): image = image.transpose((1, 2, 0)) # denormalize image += self.mean # change to RGB image = image[:, :, ::-1] return image class RandomErasing(object): """Random erasing augmentation from `Random Erasing Data Augmentation (CVPR 2017) `_. This augmentation randomly selects a rectangle region in an image and erases its pixels. Args: probability (float): The probability that the Random Erasing operation will be performed. sl (float): Minimum proportion of erased area against input image. sh (float): Maximum proportion of erased area against input image. r1 (float): Minimum aspect ratio of erased area. mean (sequence): Value to fill the erased area. """ def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)): self.probability = probability self.mean = mean self.sl = sl self.sh = sh self.r1 = r1 def __call__(self, img): if random.uniform(0, 1) >= self.probability: return img for attempt in range(100): area = img.size()[1] * img.size()[2] target_area = random.uniform(self.sl, self.sh) * area aspect_ratio = random.uniform(self.r1, 1 / self.r1) h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) if w < img.size()[2] and h < img.size()[1]: x1 = random.randint(0, img.size()[1] - h) y1 = random.randint(0, img.size()[2] - w) if img.size()[0] == 3: img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] img[1, x1:x1 + h, y1:y1 + w] = self.mean[1] img[2, x1:x1 + h, y1:y1 + w] = self.mean[2] else: img[0, x1:x1 + h, y1:y1 + w] = self.mean[0] return img return img def __repr__(self): return self.__class__.__name__ + '(p={})'.format(self.probability) ================================================ FILE: tllib/vision/transforms/keypoint_detection.py ================================================ """ @author: Junguang Jiang @contact: 
JiangJunguang1123@outlook.com """ # TODO needs better documentation import numpy as np from PIL import ImageFilter, Image import torchvision.transforms.functional as F import torchvision.transforms.transforms as T import numbers import random import math import warnings from typing import ClassVar def wrapper(transform: ClassVar): """ Wrap a transform for classification to a transform for keypoint detection. Note that the keypoint detection label will keep the same before and after wrapper. Args: transform (class, callable): transform for classification Returns: transform for keypoint detection """ class WrapperTransform(transform): def __call__(self, image, **kwargs): image = super().__call__(image) return image, kwargs return WrapperTransform ToTensor = wrapper(T.ToTensor) Normalize = wrapper(T.Normalize) ColorJitter = wrapper(T.ColorJitter) def resize(image: Image.Image, size: int, interpolation=Image.BILINEAR, keypoint2d: np.ndarray=None, intrinsic_matrix: np.ndarray=None): width, height = image.size assert width == height factor = float(size) / float(width) image = F.resize(image, size, interpolation) keypoint2d = np.copy(keypoint2d) keypoint2d *= factor intrinsic_matrix = np.copy(intrinsic_matrix) intrinsic_matrix[0][0] *= factor intrinsic_matrix[0][2] *= factor intrinsic_matrix[1][1] *= factor intrinsic_matrix[1][2] *= factor return image, keypoint2d, intrinsic_matrix def crop(image: Image.Image, top, left, height, width, keypoint2d: np.ndarray): image = F.crop(image, top, left, height, width) keypoint2d = np.copy(keypoint2d) keypoint2d[:, 0] -= left keypoint2d[:, 1] -= top return image, keypoint2d def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR, keypoint2d: np.ndarray=None, intrinsic_matrix: np.ndarray=None): """Crop the given PIL Image and resize it to desired size. Notably used in :class:`~torchvision.transforms.RandomResizedCrop`. Args: img (PIL Image): Image to be cropped. 
def center_crop(image, output_size, keypoint2d: np.ndarray):
    """Crop the given PIL Image and resize it to desired size.

    Args:
        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
        output_size (sequence or int): (height, width) of the crop box. If int, it is used for
            both directions

    Returns:
        PIL Image: Cropped image.
    """
    width, height = image.size
    crop_height, crop_width = output_size
    crop_top = int(round((height - crop_height) / 2.))
    crop_left = int(round((width - crop_width) / 2.))
    return crop(image, crop_top, crop_left, crop_height, crop_width, keypoint2d)


def hflip(image: Image.Image, keypoint2d: np.ndarray):
    # mirror the image and reflect the keypoint x-coordinates about the
    # vertical image axis
    width, height = image.size
    image = F.hflip(image)
    keypoint2d = np.copy(keypoint2d)
    keypoint2d[:, 0] = width - 1. - keypoint2d[:, 0]
    return image, keypoint2d


def rotate(image: Image.Image, angle, keypoint2d: np.ndarray):
    # rotate the image, then apply the matching 2D rotation to the keypoints
    # (the angle is negated because image coordinates have a flipped y-axis)
    image = F.rotate(image, angle)

    angle = -np.deg2rad(angle)
    keypoint2d = np.copy(keypoint2d)
    rotation_matrix = np.array([
        [np.cos(angle), -np.sin(angle)],
        [np.sin(angle), np.cos(angle)]
    ])
    width, height = image.size
    # rotate about the image center: translate, rotate, translate back
    keypoint2d[:, 0] = keypoint2d[:, 0] - width / 2
    keypoint2d[:, 1] = keypoint2d[:, 1] - height / 2
    keypoint2d = np.matmul(rotation_matrix, keypoint2d.T).T
    keypoint2d[:, 0] = keypoint2d[:, 0] + width / 2
    keypoint2d[:, 1] = keypoint2d[:, 1] + height / 2
    return image, keypoint2d


def resize_pad(img, keypoint2d, size, interpolation=Image.BILINEAR):
    # scale the longer side to ``size`` and zero-pad the shorter side so the
    # output is a (size x size) square; keypoints are scaled and shifted to match
    w, h = img.size
    if w < h:
        oh = size
        ow = int(size * w / h)
        img = img.resize((ow, oh), interpolation)
        pad_top = pad_bottom = 0
        pad_left = math.floor((size - ow) / 2)
        pad_right = math.ceil((size - ow) / 2)
        keypoint2d = keypoint2d * oh / h
        keypoint2d[:, 0] += (size - ow) / 2
    else:
        ow = size
        oh = int(size * h / w)
        img = img.resize((ow, oh), interpolation)
        pad_top = math.floor((size - oh) / 2)
        pad_bottom = math.ceil((size - oh) / 2)
        pad_left = pad_right = 0
        keypoint2d = keypoint2d * ow / w
        keypoint2d[:, 1] += (size - oh) / 2
        keypoint2d[:, 0] += (size - ow) / 2  # note: ow == size here, so this adds 0
    img = np.asarray(img)
    img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
                 'constant', constant_values=0)
    return Image.fromarray(img), keypoint2d


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, **kwargs):
        # each transform returns (image, kwargs); thread the kwargs through so
        # keypoint / intrinsics / depth annotations stay in sync with the image
        for t in self.transforms:
            image, kwargs = t(image, **kwargs)
        return image, kwargs
""" def __init__(self, transforms): self.transforms = transforms def __call__(self, image, **kwargs): for t in self.transforms: image, kwargs = t(image, **kwargs) return image, kwargs class GaussianBlur(object): def __init__(self, low=0, high=0.8): self.low = low self.high = high def __call__(self, image: Image, **kwargs): radius = np.random.uniform(low=self.low, high=self.high) image = image.filter(ImageFilter.GaussianBlur(radius)) return image, kwargs class Resize(object): """Resize the input PIL Image to the given size. """ def __init__(self, size, interpolation=Image.BILINEAR): assert isinstance(size, int) self.size = size self.interpolation = interpolation def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs): image, keypoint2d, intrinsic_matrix = resize(image, self.size, self.interpolation, keypoint2d, intrinsic_matrix) kwargs.update(keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix) if 'depth' in kwargs: kwargs['depth'] = F.resize(kwargs['depth'], self.size) return image, kwargs class ResizePad(object): """Pad the given image on all sides with the given "pad" value to resize the image to the given size. """ def __init__(self, size, interpolation=Image.BILINEAR): self.size = size self.interpolation = interpolation def __call__(self, img, keypoint2d, **kwargs): image, keypoint2d = resize_pad(img, keypoint2d, self.size, self.interpolation) kwargs.update(keypoint2d=keypoint2d) return image, kwargs class CenterCrop(object): """Crops the given PIL Image at the center. """ def __init__(self, size): if isinstance(size, numbers.Number): self.size = (int(size), int(size)) else: self.size = size def __call__(self, image, keypoint2d, **kwargs): """ Args: img (PIL Image): Image to be cropped. Returns: PIL Image: Cropped image. 
""" image, keypoint2d = center_crop(image, self.size, keypoint2d) kwargs.update(keypoint2d=keypoint2d) if 'depth' in kwargs: kwargs['depth'] = F.center_crop(kwargs['depth'], self.size) return image, kwargs class RandomRotation(object): """Rotate the image by angle. Args: degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). """ def __init__(self, degrees): if isinstance(degrees, numbers.Number): if degrees < 0: raise ValueError("If degrees is a single number, it must be positive.") self.degrees = (-degrees, degrees) else: if len(degrees) != 2: raise ValueError("If degrees is a sequence, it must be of len 2.") self.degrees = degrees @staticmethod def get_params(degrees): """Get parameters for ``rotate`` for a random rotation. Returns: sequence: params to be passed to ``rotate`` for random rotation. """ angle = random.uniform(degrees[0], degrees[1]) return angle def __call__(self, image, keypoint2d, **kwargs): """ Args: img (PIL Image): Image to be rotated. Returns: PIL Image: Rotated image. """ angle = self.get_params(self.degrees) image, keypoint2d = rotate(image, angle, keypoint2d) kwargs.update(keypoint2d=keypoint2d) if 'depth' in kwargs: kwargs['depth'] = F.rotate(kwargs['depth'], angle) return image, kwargs class RandomResizedCrop(object): """Crop the given PIL Image to random size and aspect ratio. A crop of random size (default: of 0.08 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to given size. This is popularly used to train the Inception networks. 
Args: size: expected output size of each edge scale: range of size of the origin size cropped ratio: range of aspect ratio of the origin aspect ratio cropped interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, scale=(0.6, 1.3), interpolation=Image.BILINEAR): self.size = size if scale[0] > scale[1]: warnings.warn("range should be of kind (min, max)") self.interpolation = interpolation self.scale = scale @staticmethod def get_params(img, scale): """Get parameters for ``crop`` for a random sized crop. Args: img (PIL Image): Image to be cropped. scale (tuple): range of size of the origin size cropped Returns: tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ width, height = img.size area = height * width for attempt in range(10): target_area = random.uniform(*scale) * area aspect_ratio = 1 w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: i = random.randint(0, height - h) j = random.randint(0, width - w) return i, j, h, w # Fallback to whole image return 0, 0, height, width def __call__(self, image, keypoint2d: np.ndarray, intrinsic_matrix: np.ndarray, **kwargs): """ Args: img (PIL Image): Image to be cropped and resized. Returns: PIL Image: Randomly cropped and resized image. """ i, j, h, w = self.get_params(image, self.scale) image, keypoint2d, intrinsic_matrix = resized_crop(image, i, j, h, w, self.size, self.interpolation, keypoint2d, intrinsic_matrix) kwargs.update(keypoint2d=keypoint2d, intrinsic_matrix=intrinsic_matrix) if 'depth' in kwargs: kwargs['depth'] = F.resized_crop(kwargs['depth'], i, j, h, w, self.size, self.interpolation,) return image, kwargs class RandomApply(T.RandomTransforms): """Apply randomly a list of transformations with a given probability. 
Args: transforms (list or tuple or torch.nn.Module): list of transformations p (float): probability """ def __init__(self, transforms, p=0.5): super(RandomApply, self).__init__(transforms) self.p = p def __call__(self, image, **kwargs): if self.p < random.random(): return image, kwargs for t in self.transforms: image, kwargs = t(image, **kwargs) return image, kwargs ================================================ FILE: tllib/vision/transforms/segmentation.py ================================================ """ @author: Junguang Jiang @contact: JiangJunguang1123@outlook.com """ from PIL import Image import random import math from typing import ClassVar, Sequence, List, Tuple from torch import Tensor import torch import torchvision.transforms.functional as F import torchvision.transforms.transforms as T import torch.nn as nn from . import MultipleApply as MultipleApplyBase, NormalizeAndTranspose as NormalizeAndTransposeBase def wrapper(transform: ClassVar): """ Wrap a transform for classification to a transform for segmentation. Note that the segmentation label will keep the same before and after wrapper. Args: transform (class, callable): transform for classification Returns: transform for segmentation """ class WrapperTransform(transform): def __call__(self, image, label): image = super().__call__(image) return image, label return WrapperTransform ColorJitter = wrapper(T.ColorJitter) Normalize = wrapper(T.Normalize) ToTensor = wrapper(T.ToTensor) ToPILImage = wrapper(T.ToPILImage) MultipleApply = wrapper(MultipleApplyBase) NormalizeAndTranspose = wrapper(NormalizeAndTransposeBase) class Compose: """Composes several transforms together. Args: transforms (list): list of transforms to compose. 
Example: >>> Compose([ >>> Resize((512, 512)), >>> RandomHorizontalFlip() >>> ]) """ def __init__(self, transforms): super(Compose, self).__init__() self.transforms = transforms def __call__(self, image, target): for t in self.transforms: image, target = t(image, target) return image, target class Resize(nn.Module): """Resize the input image and the corresponding label to the given size. The image should be a PIL Image. Args: image_size (sequence): The requested image size in pixels, as a 2-tuple: (width, height). label_size (sequence, optional): The requested segmentation label size in pixels, as a 2-tuple: (width, height). The same as image_size if None. Default: None. """ def __init__(self, image_size, label_size=None): super(Resize, self).__init__() self.image_size = image_size if label_size is None: self.label_size = image_size else: self.label_size = label_size def forward(self, image, label): """ Args: image: (PIL Image): Image to be scaled. label: (PIL Image): Segmentation label to be scaled. Returns: Rescaled image, rescaled segmentation label """ # resize image = image.resize(self.image_size, Image.BICUBIC) label = label.resize(self.label_size, Image.NEAREST) return image, label class RandomCrop(nn.Module): """Crop the given image at a random location. The image can be a PIL Image Args: size (sequence): Desired output size of the crop. """ def __init__(self, size): super(RandomCrop, self).__init__() self.size = size def forward(self, image, label): """ Args: image: (PIL Image): Image to be cropped. label: (PIL Image): Segmentation label to be cropped. Returns: Cropped image, cropped segmentation label. 
""" # random crop left = image.size[0] - self.size[0] upper = image.size[1] - self.size[1] left = random.randint(0, left-1) upper = random.randint(0, upper-1) right = left + self.size[0] lower = upper + self.size[1] image = image.crop((left, upper, right, lower)) label = label.crop((left, upper, right, lower)) return image, label class RandomHorizontalFlip(nn.Module): """Horizontally flip the given PIL Image randomly with a given probability. Args: p (float): probability of the image being flipped. Default value is 0.5 """ def __init__(self, p=0.5): super(RandomHorizontalFlip, self).__init__() self.p = p def forward(self, image, label): """ Args: image: (PIL Image): Image to be flipped. label: (PIL Image): Segmentation label to be flipped. Returns: Randomly flipped image, randomly flipped segmentation label. """ if random.random() < self.p: return F.hflip(image), F.hflip(label) return image, label class RandomResizedCrop(T.RandomResizedCrop): """Crop the given image to random size and aspect ratio. The image can be a PIL Image. A crop of random size (default: of 0.5 to 1.0) of the original size and a random aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop is finally resized to given size. Args: size (int or sequence): expected output size of each edge. If size is an int instead of sequence like (h, w), a square output size ``(size, size)`` is made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). scale (tuple of float): range of size of the origin size cropped ratio (tuple of float): range of aspect ratio of the origin aspect ratio cropped. interpolation: Default: PIL.Image.BILINEAR """ def __init__(self, size, scale=(0.5, 1.0), ratio=(3. / 4., 4. 
/ 3.), interpolation=Image.BICUBIC): super(RandomResizedCrop, self).__init__(size, scale, ratio, interpolation) @staticmethod def get_params( img: Tensor, scale: List[float], ratio: List[float] ) -> Tuple[int, int, int, int]: """Get parameters for ``crop`` for a random sized crop. Args: img (PIL Image): Input image. scale (list): range of scale of the origin size cropped ratio (list): range of aspect ratio of the origin aspect ratio cropped Returns: params (i, j, h, w) to be passed to ``crop`` for a random sized crop. """ width, height = F._get_image_size(img) area = height * width for _ in range(10): target_area = area * random.uniform(scale[0], scale[1]) log_ratio = torch.log(torch.tensor(ratio)) aspect_ratio = math.exp(random.uniform(log_ratio[0], log_ratio[1])) w = int(round(math.sqrt(target_area * aspect_ratio))) h = int(round(math.sqrt(target_area / aspect_ratio))) if 0 < w <= width and 0 < h <= height: i = random.randint(0, height - h) j = random.randint(0, width - w) return i, j, h, w # Fallback to central crop in_ratio = float(width) / float(height) if in_ratio < min(ratio): w = width h = int(round(w / min(ratio))) elif in_ratio > max(ratio): h = height w = int(round(h * max(ratio))) else: # whole image w = width h = height i = (height - h) // 2 j = (width - w) // 2 return i, j, h, w def forward(self, image, label): """ Args: image: (PIL Image): Image to be cropped and resized. label: (PIL Image): Segmentation label to be cropped and resized. Returns: Randomly cropped and resized image, randomly cropped and resized segmentation label. """ top, left, height, width = self.get_params(image, self.scale, self.ratio) image = image.crop((left, top, left + width, top + height)) image = image.resize(self.size, self.interpolation) label = label.crop((left, top, left + width, top + height)) label = label.resize(self.size, Image.NEAREST) return image, label class RandomChoice(T.RandomTransforms): """Apply single transformation randomly picked from a list. 
""" def __call__(self, image, label): t = random.choice(self.transforms) return t(image, label) class RandomApply(T.RandomTransforms): """Apply randomly a list of transformations with a given probability. Args: transforms (list or tuple or torch.nn.Module): list of transformations p (float): probability """ def __init__(self, transforms, p=0.5): super(RandomApply, self).__init__(transforms) self.p = p def __call__(self, image, label): if self.p < random.random(): return image for t in self.transforms: image, label = t(image, label) return image